mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 00:15:37 +01:00
Compare commits
20 Commits
batched-pa
...
multiport
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0496ef101e | ||
|
|
ae9de47dd9 | ||
|
|
4eb86afa54 | ||
|
|
f36db374ac | ||
|
|
7ac51c1af2 | ||
|
|
dabce8a1b4 | ||
|
|
6b78e9cdb3 | ||
|
|
b445d14ddb | ||
|
|
6606124bf9 | ||
|
|
05405bc261 | ||
|
|
b033267d6e | ||
|
|
659d7fece6 | ||
|
|
f2aef0d6eb | ||
|
|
a2b9747b0f | ||
|
|
0e593ad582 | ||
|
|
28ecfcbc03 | ||
|
|
e71059a410 | ||
|
|
aec7f5f865 | ||
|
|
6d8e939648 | ||
|
|
326fc8758d |
4
.github/workflows/gofmt.yml
vendored
4
.github/workflows/gofmt.yml
vendored
@@ -16,9 +16,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Install goimports
|
||||
|
||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -12,9 +12,9 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -35,9 +35,9 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -68,9 +68,9 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Import certificates
|
||||
|
||||
4
.github/workflows/smoke-extra.yml
vendored
4
.github/workflows/smoke-extra.yml
vendored
@@ -22,9 +22,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version-file: 'go.mod'
|
||||
check-latest: true
|
||||
|
||||
- name: add hashicorp source
|
||||
|
||||
12
.github/workflows/smoke.yml
vendored
12
.github/workflows/smoke.yml
vendored
@@ -20,9 +20,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: build
|
||||
@@ -52,4 +52,12 @@ jobs:
|
||||
working-directory: ./.github/workflows/smoke
|
||||
run: NAME="smoke-p256" ./smoke.sh
|
||||
|
||||
- name: setup docker image for multiport
|
||||
working-directory: ./.github/workflows/smoke
|
||||
run: NAME="smoke-multiport" MULTIPORT_TX=true MULTIPORT_RX=true MULTIPORT_HANDSHAKE=true ./build.sh
|
||||
|
||||
- name: run smoke
|
||||
working-directory: ./.github/workflows/smoke
|
||||
run: NAME="smoke-multiport" ./smoke.sh
|
||||
|
||||
timeout-minutes: 10
|
||||
|
||||
4
.github/workflows/smoke/genconfig.sh
vendored
4
.github/workflows/smoke/genconfig.sh
vendored
@@ -48,6 +48,10 @@ listen:
|
||||
|
||||
tun:
|
||||
dev: ${TUN_DEV:-tun0}
|
||||
multiport:
|
||||
tx_enabled: ${MULTIPORT_TX:-false}
|
||||
rx_enabled: ${MULTIPORT_RX:-false}
|
||||
tx_handshake: ${MULTIPORT_HANDSHAKE:-false}
|
||||
|
||||
firewall:
|
||||
inbound_action: reject
|
||||
|
||||
20
.github/workflows/test.yml
vendored
20
.github/workflows/test.yml
vendored
@@ -20,9 +20,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: v2.5
|
||||
version: v2.1
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
@@ -58,9 +58,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -79,9 +79,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.22'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -100,9 +100,9 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build nebula
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: v2.5
|
||||
version: v2.1
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
|
||||
4
Makefile
4
Makefile
@@ -227,6 +227,10 @@ smoke-relay-docker: bin-docker
|
||||
cd .github/workflows/smoke/ && ./build-relay.sh
|
||||
cd .github/workflows/smoke/ && ./smoke-relay.sh
|
||||
|
||||
smoke-multiport-docker: bin-docker
|
||||
cd .github/workflows/smoke/ && NAME="smoke-multiport" MULTIPORT_TX=true MULTIPORT_RX=true MULTIPORT_HANDSHAKE=true ./build.sh
|
||||
cd .github/workflows/smoke/ && NAME="smoke-multiport" ./smoke.sh
|
||||
|
||||
smoke-docker-race: BUILD_ARGS = -race
|
||||
smoke-docker-race: CGO_ENABLED = 1
|
||||
smoke-docker-race: smoke-docker
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// batchPipelines tracks whether the inside device can operate on packet batches
|
||||
// and, if so, holds the shared packet pool sized for the virtio headroom and
|
||||
// payload limits advertised by the device. It also owns the fan-in/fan-out
|
||||
// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
|
||||
type batchPipelines struct {
|
||||
enabled bool
|
||||
inside overlay.BatchCapableDevice
|
||||
headroom int
|
||||
payloadCap int
|
||||
pool *overlay.PacketPool
|
||||
batchSize int
|
||||
routines int
|
||||
rxQueues []chan *overlay.Packet
|
||||
txQueues []chan queuedDatagram
|
||||
tunQueues []chan *overlay.Packet
|
||||
}
|
||||
|
||||
type queuedDatagram struct {
|
||||
packet *overlay.Packet
|
||||
addr netip.AddrPort
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
|
||||
if device == nil || routines <= 0 {
|
||||
return
|
||||
}
|
||||
bcap, ok := device.(overlay.BatchCapableDevice)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
headroom := bcap.BatchHeadroom()
|
||||
payload := bcap.BatchPayloadCap()
|
||||
if maxSegments < 1 {
|
||||
maxSegments = 1
|
||||
}
|
||||
requiredPayload := udp.MTU * maxSegments
|
||||
if payload < requiredPayload {
|
||||
payload = requiredPayload
|
||||
}
|
||||
batchSize := bcap.BatchSize()
|
||||
if headroom <= 0 || payload <= 0 || batchSize <= 0 {
|
||||
return
|
||||
}
|
||||
bp.enabled = true
|
||||
bp.inside = bcap
|
||||
bp.headroom = headroom
|
||||
bp.payloadCap = payload
|
||||
bp.batchSize = batchSize
|
||||
bp.routines = routines
|
||||
bp.pool = overlay.NewPacketPool(headroom, payload)
|
||||
queueCap := batchSize * defaultBatchQueueDepthFactor
|
||||
if queueDepth > 0 {
|
||||
queueCap = queueDepth
|
||||
}
|
||||
if queueCap < batchSize {
|
||||
queueCap = batchSize
|
||||
}
|
||||
bp.rxQueues = make([]chan *overlay.Packet, routines)
|
||||
bp.txQueues = make([]chan queuedDatagram, routines)
|
||||
bp.tunQueues = make([]chan *overlay.Packet, routines)
|
||||
for i := 0; i < routines; i++ {
|
||||
bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
|
||||
bp.txQueues[i] = make(chan queuedDatagram, queueCap)
|
||||
bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
|
||||
}
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) Pool() *overlay.PacketPool {
|
||||
if bp == nil || !bp.enabled {
|
||||
return nil
|
||||
}
|
||||
return bp.pool
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) Enabled() bool {
|
||||
return bp != nil && bp.enabled
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) batchSizeHint() int {
|
||||
if bp == nil || bp.batchSize <= 0 {
|
||||
return 1
|
||||
}
|
||||
return bp.batchSize
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
|
||||
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
|
||||
return nil
|
||||
}
|
||||
return bp.rxQueues[i]
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
|
||||
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
|
||||
return nil
|
||||
}
|
||||
return bp.txQueues[i]
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
|
||||
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
|
||||
return nil
|
||||
}
|
||||
return bp.tunQueues[i]
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) txQueueLen(i int) int {
|
||||
q := bp.txQueue(i)
|
||||
if q == nil {
|
||||
return 0
|
||||
}
|
||||
return len(q)
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) tunQueueLen(i int) int {
|
||||
q := bp.tunQueue(i)
|
||||
if q == nil {
|
||||
return 0
|
||||
}
|
||||
return len(q)
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
|
||||
q := bp.rxQueue(i)
|
||||
if q == nil {
|
||||
return false
|
||||
}
|
||||
q <- pkt
|
||||
return true
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
|
||||
q := bp.txQueue(i)
|
||||
if q == nil {
|
||||
return false
|
||||
}
|
||||
q <- queuedDatagram{packet: pkt, addr: addr}
|
||||
return true
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
|
||||
q := bp.tunQueue(i)
|
||||
if q == nil {
|
||||
return false
|
||||
}
|
||||
q <- pkt
|
||||
return true
|
||||
}
|
||||
|
||||
func (bp *batchPipelines) newPacket() *overlay.Packet {
|
||||
if bp == nil || !bp.enabled || bp.pool == nil {
|
||||
return nil
|
||||
}
|
||||
return bp.pool.Get()
|
||||
}
|
||||
@@ -84,11 +84,16 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
||||
|
||||
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
||||
|
||||
rawMap, ok := value.(map[string]any)
|
||||
rawMap, ok := value.(map[any]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||
}
|
||||
for rawCIDR, rawValue := range rawMap {
|
||||
for rawKey, rawValue := range rawMap {
|
||||
rawCIDR, ok := rawKey.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||
}
|
||||
|
||||
cidr, err := netip.ParsePrefix(rawCIDR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
@@ -124,7 +129,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
|
||||
}
|
||||
|
||||
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
||||
rawMap, ok := raw.(map[string]any)
|
||||
rawMap, ok := raw.(map[any]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid type: %T", raw)
|
||||
}
|
||||
|
||||
@@ -58,9 +58,6 @@ type Certificate interface {
|
||||
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
||||
PublicKey() []byte
|
||||
|
||||
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
|
||||
MarshalPublicKeyPEM() []byte
|
||||
|
||||
// Curve identifies which curve was used for the PublicKey and Signature.
|
||||
Curve() Curve
|
||||
|
||||
@@ -138,7 +135,8 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
|
||||
case Version2:
|
||||
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
|
||||
default:
|
||||
return nil, ErrUnknownVersion
|
||||
//TODO: CERT-V2 make a static var
|
||||
return nil, fmt.Errorf("unknown certificate version %d", v)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -83,10 +83,6 @@ func (c *certificateV1) PublicKey() []byte {
|
||||
return c.details.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
|
||||
return marshalCertPublicKeyToPEM(c)
|
||||
}
|
||||
|
||||
func (c *certificateV1) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
@@ -114,10 +110,8 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
)
|
||||
|
||||
func TestCertificateV1_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV1_PublicKeyPem(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{},
|
||||
unsafeNetworks: []netip.Prefix{},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
publicKey: pubKey,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
assert.Equal(t, Version1, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.False(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = true
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||
require.NoError(t, err)
|
||||
nc.details.curve = Curve_P256
|
||||
nc.details.publicKey = pubP256Key
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = false
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.False(t, nc.IsCA())
|
||||
}
|
||||
|
||||
func TestCertificateV1_Expired(t *testing.T) {
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
|
||||
@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
|
||||
return c.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
|
||||
return marshalCertPublicKeyToPEM(c)
|
||||
}
|
||||
|
||||
func (c *certificateV2) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
@@ -153,10 +149,8 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
)
|
||||
|
||||
func TestCertificateV2_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV2_PublicKeyPem(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{},
|
||||
unsafeNetworks: []netip.Prefix{},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
publicKey: pubKey,
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
assert.Equal(t, Version2, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.False(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = true
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||
require.NoError(t, err)
|
||||
nc.curve = Curve_P256
|
||||
nc.publicKey = pubP256Key
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = false
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.False(t, nc.IsCA())
|
||||
}
|
||||
|
||||
func TestCertificateV2_Expired(t *testing.T) {
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
|
||||
@@ -20,7 +20,6 @@ var (
|
||||
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
||||
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
||||
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
||||
ErrUnknownVersion = errors.New("certificate version unrecognized")
|
||||
|
||||
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
||||
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
||||
|
||||
149
cert/pem.go
149
cert/pem.go
@@ -1,34 +1,25 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
const ( //cert banners
|
||||
CertificateBanner = "NEBULA CERTIFICATE"
|
||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||
)
|
||||
const (
|
||||
CertificateBanner = "NEBULA CERTIFICATE"
|
||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||
|
||||
const ( //key-agreement-key banners
|
||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||
)
|
||||
|
||||
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
|
||||
const ( //signing key banners
|
||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
|
||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||
)
|
||||
|
||||
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
||||
@@ -60,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
|
||||
}
|
||||
|
||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||
if c.IsCA() {
|
||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
} else {
|
||||
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
|
||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
@@ -81,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
|
||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
|
||||
case Curve_P256:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
k, r := pem.Decode(b)
|
||||
if k == nil {
|
||||
@@ -105,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
||||
expectedLen = 32
|
||||
curve = Curve_CURVE25519
|
||||
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
|
||||
case P256PublicKeyBanner:
|
||||
// Uncompressed
|
||||
expectedLen = 65
|
||||
curve = Curve_P256
|
||||
@@ -140,101 +108,6 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// Backward compatibility functions for older API
|
||||
func MarshalX25519PublicKey(b []byte) []byte {
|
||||
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
|
||||
}
|
||||
|
||||
func MarshalX25519PrivateKey(b []byte) []byte {
|
||||
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
|
||||
}
|
||||
|
||||
func MarshalPublicKey(curve Curve, b []byte) []byte {
|
||||
return MarshalPublicKeyToPEM(curve, b)
|
||||
}
|
||||
|
||||
func MarshalPrivateKey(curve Curve, b []byte) []byte {
|
||||
return MarshalPrivateKeyToPEM(curve, b)
|
||||
}
|
||||
|
||||
// NebulaCertificate is a compatibility wrapper for the old API
|
||||
type NebulaCertificate struct {
|
||||
Details NebulaCertificateDetails
|
||||
Signature []byte
|
||||
cert Certificate
|
||||
}
|
||||
|
||||
// NebulaCertificateDetails is a compatibility wrapper for certificate details
|
||||
type NebulaCertificateDetails struct {
|
||||
Name string
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
PublicKey []byte
|
||||
IsCA bool
|
||||
Issuer []byte
|
||||
Curve Curve
|
||||
}
|
||||
|
||||
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
|
||||
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
|
||||
c, rest, err := UnmarshalCertificateFromPEM(b)
|
||||
if err != nil {
|
||||
return nil, rest, err
|
||||
}
|
||||
|
||||
issuerBytes, err := func() ([]byte, error) {
|
||||
issuer := c.Issuer()
|
||||
if issuer == "" {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := hex.DecodeString(issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
|
||||
}
|
||||
return decoded, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, rest, err
|
||||
}
|
||||
|
||||
pubKey := c.PublicKey()
|
||||
if pubKey != nil {
|
||||
pubKey = append([]byte(nil), pubKey...)
|
||||
}
|
||||
|
||||
sig := c.Signature()
|
||||
if sig != nil {
|
||||
sig = append([]byte(nil), sig...)
|
||||
}
|
||||
|
||||
return &NebulaCertificate{
|
||||
Details: NebulaCertificateDetails{
|
||||
Name: c.Name(),
|
||||
NotBefore: c.NotBefore(),
|
||||
NotAfter: c.NotAfter(),
|
||||
PublicKey: pubKey,
|
||||
IsCA: c.IsCA(),
|
||||
Issuer: issuerBytes,
|
||||
Curve: c.Curve(),
|
||||
},
|
||||
Signature: sig,
|
||||
cert: c,
|
||||
}, rest, nil
|
||||
}
|
||||
|
||||
// IssuerString returns the issuer in hex format for compatibility
|
||||
func (n *NebulaCertificate) IssuerString() string {
|
||||
if n.Details.Issuer == nil {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(n.Details.Issuer)
|
||||
}
|
||||
|
||||
// Certificate returns the underlying certificate (read-only)
|
||||
func (n *NebulaCertificate) Certificate() Certificate {
|
||||
return n.cert
|
||||
}
|
||||
|
||||
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
||||
// consumed data or an error on failure
|
||||
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
|
||||
@@ -177,7 +177,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
}
|
||||
|
||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||
t.Parallel()
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
@@ -231,7 +230,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
}
|
||||
|
||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
@@ -242,12 +240,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
oldPubP256Key := []byte(`# A good key
|
||||
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA ECDSA P256 PUBLIC KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A short key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
@@ -264,22 +256,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-END NEBULA X25519 PUBLIC KEY-----`)
|
||||
|
||||
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
|
||||
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||
assert.Len(t, k, 32)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Len(t, k, 65)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Len(t, k, 65)
|
||||
|
||||
12
cert/sign.go
12
cert/sign.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
@@ -54,10 +55,15 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
|
||||
}
|
||||
return t.SignWith(signer, curve, sp)
|
||||
case Curve_P256:
|
||||
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
pk := &ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
Curve: elliptic.P256(),
|
||||
},
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
||||
D: new(big.Int).SetBytes(key),
|
||||
}
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
||||
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
|
||||
sp := func(certBytes []byte) ([]byte, error) {
|
||||
// We need to hash first for ECDSA
|
||||
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
||||
|
||||
@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||
}
|
||||
|
||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||
nc := &cert.TBSCertificate{
|
||||
Version: v,
|
||||
Curve: c.Curve(),
|
||||
Name: c.Name(),
|
||||
Networks: c.Networks(),
|
||||
UnsafeNetworks: c.UnsafeNetworks(),
|
||||
Groups: c.Groups(),
|
||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||
PublicKey: c.PublicKey(),
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pem
|
||||
}
|
||||
|
||||
func X25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
|
||||
@@ -354,8 +354,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
if mainHostInfo {
|
||||
decision = tryRehandshake
|
||||
|
||||
} else {
|
||||
if cm.shouldSwapPrimary(hostinfo) {
|
||||
if cm.shouldSwapPrimary(hostinfo, primary) {
|
||||
decision = swapPrimary
|
||||
} else {
|
||||
// migrate the relays to the primary, if in use.
|
||||
@@ -446,7 +447,7 @@ func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time
|
||||
return inactiveDuration, true
|
||||
}
|
||||
|
||||
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
||||
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
|
||||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||
// Let's sort this out.
|
||||
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
}
|
||||
|
||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||
if crt == nil {
|
||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||
return true
|
||||
}
|
||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||
// settle down.
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
cm.hostMap.Unlock()
|
||||
}
|
||||
|
||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||
// check and return true.
|
||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||
remoteCert := hostinfo.GetCert()
|
||||
if remoteCert == nil {
|
||||
return false //don't tear down tunnels for handshakes in progress
|
||||
return false
|
||||
}
|
||||
|
||||
caPool := cm.intf.pki.GetCAPool()
|
||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||
if err == nil {
|
||||
return false //cert is still valid! yay!
|
||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||
// Block listed certificates should always be disconnected
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||
return true
|
||||
} else if cm.intf.disconnectInvalid.Load() {
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
return true
|
||||
} else {
|
||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||
return false
|
||||
}
|
||||
|
||||
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||
// Block listed certificates should always be disconnected
|
||||
return false
|
||||
}
|
||||
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := cm.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
curCrtVersion := curCrt.Version()
|
||||
myCrt := cs.getCertificate(curCrtVersion)
|
||||
if myCrt == nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("reason", "local certificate removed").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
return
|
||||
}
|
||||
peerCrt := hostinfo.ConnectionState.peerCert
|
||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
if curCrtVersion < cs.initiatingVersion {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
|
||||
addrMap: map[netip.Addr]*RemoteList{},
|
||||
queryChan: make(chan netip.Addr, 10),
|
||||
}
|
||||
lighthouses := []netip.Addr{}
|
||||
lighthouses := map[netip.Addr]struct{}{}
|
||||
staticList := map[netip.Addr]struct{}{}
|
||||
|
||||
lh.lighthouses.Store(&lighthouses)
|
||||
@@ -446,10 +446,6 @@ func (d *dummyCert) PublicKey() []byte {
|
||||
return d.publicKey
|
||||
}
|
||||
|
||||
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
|
||||
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
|
||||
}
|
||||
|
||||
func (d *dummyCert) Signature() []byte {
|
||||
return d.signature
|
||||
}
|
||||
|
||||
@@ -129,109 +129,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
// newServer creates a nebula instance with fewer assumptions
|
||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnNetworks := certs[len(certs)-1].Networks()
|
||||
|
||||
var udpAddr netip.AddrPort
|
||||
if vpnNetworks[0].Addr().Is4() {
|
||||
budpIp := vpnNetworks[0].Addr().As4()
|
||||
budpIp[1] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||
} else {
|
||||
budpIp := vpnNetworks[0].Addr().As16()
|
||||
// beef for funsies
|
||||
budpIp[2] = 190
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
|
||||
caStr := ""
|
||||
for _, ca := range caCrt {
|
||||
x, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr += string(x)
|
||||
}
|
||||
certStr := ""
|
||||
for _, c := range certs {
|
||||
x, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
certStr += string(x)
|
||||
}
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": certStr,
|
||||
"key": string(key),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": m{
|
||||
"outbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
},
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": m{
|
||||
"host": udpAddr.Addr().String(),
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||
"level": l.Level.String(),
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
"connection_alive_interval": 2,
|
||||
},
|
||||
}
|
||||
|
||||
if overrides != nil {
|
||||
final := m{}
|
||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = final
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
c := config.NewC(l)
|
||||
cStr := string(cb)
|
||||
c.LoadString(cStr)
|
||||
|
||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
type doneCb func()
|
||||
|
||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||
|
||||
@@ -4,16 +4,12 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestDropInactiveTunnels(t *testing.T) {
|
||||
@@ -59,262 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertUpgrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCert2Pem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v2-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
if c == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
r.Logf("version %d", version)
|
||||
if version == cert.Version2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*10 {
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertDowngrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCertPem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v1-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == cert.Version1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertMismatchCorrection(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == theirVersion {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("wtf")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
@@ -280,6 +280,47 @@ tun:
|
||||
# SO_RCVBUFFORCE is used to avoid having to raise the system wide max
|
||||
#use_system_route_table_buffer_size: 0
|
||||
|
||||
# EXPERIMENTAL: This option may change or disappear in the future.
|
||||
# Multiport spreads outgoing UDP packets across multiple UDP send ports,
|
||||
# which allows nebula to work around any issues on the underlay network.
|
||||
# Some example issues this could work around:
|
||||
# - UDP rate limits on a per flow basis.
|
||||
# - Partial underlay network failure in which some flows work and some don't
|
||||
# Agreement is done during the handshake to decide if multiport mode will
|
||||
# be used for a given tunnel (one side must have tx_enabled set, the other
|
||||
# side must have rx_enabled set)
|
||||
#
|
||||
# NOTE: you cannot use multiport on a host if you are relying on UDP hole
|
||||
# punching to get through a NAT or firewall.
|
||||
#
|
||||
# NOTE: Linux only (uses raw sockets to send). Also currently only works
|
||||
# with IPv4 underlay network remotes.
|
||||
#
|
||||
# The default values are listed below:
|
||||
#multiport:
|
||||
# This host support sending via multiple UDP ports.
|
||||
#tx_enabled: false
|
||||
#
|
||||
# This host supports receiving packets sent from multiple UDP ports.
|
||||
#rx_enabled: false
|
||||
#
|
||||
# How many UDP ports to use when sending. The lowest source port will be
|
||||
# listen.port and go up to (but not including) listen.port + tx_ports.
|
||||
#tx_ports: 100
|
||||
#
|
||||
# NOTE: All of your hosts must be running a version of Nebula that supports
|
||||
# multiport if you want to enable this feature. Older versions of Nebula
|
||||
# will be confused by these multiport handshakes.
|
||||
#
|
||||
# If handshakes are not getting a response, attempt to transmit handshakes
|
||||
# using random UDP source ports (to get around partial underlay network
|
||||
# failures).
|
||||
#tx_handshake: false
|
||||
#
|
||||
# How many unresponded handshakes we should send before we attempt to
|
||||
# send multiport handshakes.
|
||||
#tx_handshake_delay: 2
|
||||
|
||||
# Configure logging level
|
||||
logging:
|
||||
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
|
||||
|
||||
12
firewall.go
12
firewall.go
@@ -423,7 +423,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error {
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(fp, h, caPool, localCache) {
|
||||
return nil
|
||||
@@ -490,9 +490,11 @@ func (f *Firewall) EmitStats() {
|
||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool {
|
||||
if localCache != nil && localCache.Has(fp) {
|
||||
return true
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
@@ -557,7 +559,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
conntrack.Unlock()
|
||||
|
||||
if localCache != nil {
|
||||
localCache.Add(fp)
|
||||
localCache[fp] = struct{}{}
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -10,58 +9,13 @@ import (
|
||||
|
||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||
// has been seen in the conntrack table.
|
||||
type ConntrackCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[Packet]struct{}
|
||||
}
|
||||
|
||||
func newConntrackCache() *ConntrackCache {
|
||||
return &ConntrackCache{entries: make(map[Packet]struct{})}
|
||||
}
|
||||
|
||||
func (c *ConntrackCache) Has(p Packet) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Lock()
|
||||
_, ok := c.entries[p]
|
||||
c.mu.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *ConntrackCache) Add(p Packet) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.entries[p] = struct{}{}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *ConntrackCache) Len() int {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
c.mu.Lock()
|
||||
l := len(c.entries)
|
||||
c.mu.Unlock()
|
||||
return l
|
||||
}
|
||||
|
||||
func (c *ConntrackCache) Reset(capHint int) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.entries = make(map[Packet]struct{}, capHint)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
type ConntrackCache map[Packet]struct{}
|
||||
|
||||
type ConntrackCacheTicker struct {
|
||||
cacheV uint64
|
||||
cacheTick atomic.Uint64
|
||||
|
||||
cache *ConntrackCache
|
||||
cache ConntrackCache
|
||||
}
|
||||
|
||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||
@@ -69,7 +23,9 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &ConntrackCacheTicker{cache: newConntrackCache()}
|
||||
c := &ConntrackCacheTicker{
|
||||
cache: ConntrackCache{},
|
||||
}
|
||||
|
||||
go c.tick(d)
|
||||
|
||||
@@ -85,17 +41,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
||||
|
||||
// Get checks if the cache ticker has moved to the next version before returning
|
||||
// the map. If it has moved, we reset the map.
|
||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
|
||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||
c.cacheV = tick
|
||||
if ll := c.cache.Len(); ll > 0 {
|
||||
if ll := len(c.cache); ll > 0 {
|
||||
if l.Level == logrus.DebugLevel {
|
||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||
}
|
||||
c.cache.Reset(ll)
|
||||
c.cache = make(ConntrackCache, ll)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package firewall
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
mathrand "math/rand"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
@@ -60,3 +61,30 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
|
||||
"Fragment": fp.Fragment,
|
||||
})
|
||||
}
|
||||
|
||||
// UDPSendPort calculates the UDP port to send from when using multiport mode.
|
||||
// The result will be from [0, numBuckets)
|
||||
func (fp Packet) UDPSendPort(numBuckets int) uint16 {
|
||||
if numBuckets <= 1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// If there is no port (like an ICMP packet), pick a random UDP send port
|
||||
if fp.LocalPort == 0 {
|
||||
return uint16(mathrand.Intn(numBuckets))
|
||||
}
|
||||
|
||||
// A decent enough 32bit hash function
|
||||
// Prospecting for Hash Functions
|
||||
// - https://nullprogram.com/blog/2018/07/31/
|
||||
// - https://github.com/skeeto/hash-prospector
|
||||
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
|
||||
x := (uint32(fp.LocalPort) << 16) | uint32(fp.RemotePort)
|
||||
x ^= x >> 16
|
||||
x *= 0x21f0aaad
|
||||
x ^= x >> 15
|
||||
x *= 0xd35a2d97
|
||||
x ^= x >> 15
|
||||
|
||||
return uint16(x) % uint16(numBuckets)
|
||||
}
|
||||
|
||||
241
firewall_test.go
241
firewall_test.go
@@ -68,9 +68,6 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
||||
require.NoError(t, err)
|
||||
|
||||
ti6, err := netip.ParsePrefix("fd12::34/128")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
// An empty rule is any
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||
@@ -95,24 +92,12 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||
@@ -132,13 +117,6 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
anyIp6, err := netip.ParsePrefix("::/0")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -221,82 +199,6 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropV6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
c := dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &cert.CachedCertificate{
|
||||
Certificate: &c,
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||
}
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
f := &Firewall{}
|
||||
ft := FirewallTable{
|
||||
@@ -306,10 +208,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
||||
|
||||
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
@@ -341,15 +239,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{},
|
||||
}
|
||||
ip := netip.MustParsePrefix("fd99::99/128")
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -363,18 +252,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -388,18 +265,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -424,17 +289,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on name", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -593,42 +447,6 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3V6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
network := netip.MustParsePrefix("fd12::34/120")
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host-owner",
|
||||
networks: []netip.Prefix{network},
|
||||
},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
// Test a remote address match
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
cp := cert.NewCAPool()
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
@@ -692,50 +510,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host-owner",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
|
||||
},
|
||||
}
|
||||
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
|
||||
},
|
||||
}
|
||||
h1 := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("192.0.2.1"),
|
||||
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
ml := func(m map[string]struct{}, a [][]string) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
@@ -953,21 +727,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||
|
||||
// Test adding rule with cidr ipv6
|
||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr ipv6
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
|
||||
37
go.mod
37
go.mod
@@ -1,6 +1,8 @@
|
||||
module github.com/slackhq/nebula
|
||||
|
||||
go 1.25
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.1
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2
|
||||
@@ -8,30 +10,30 @@ require (
|
||||
github.com/armon/go-radix v1.0.0
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/gaissmai/bart v0.25.0
|
||||
github.com/gaissmai/bart v0.20.4
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/kardianos/service v1.2.4
|
||||
github.com/miekg/dns v1.1.68
|
||||
github.com/kardianos/service v1.2.2
|
||||
github.com/miekg/dns v1.1.65
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/net v0.45.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/net v0.39.0
|
||||
golang.org/x/sync v0.13.0
|
||||
golang.org/x/sys v0.32.0
|
||||
golang.org/x/term v0.31.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/protobuf v1.36.8
|
||||
google.golang.org/protobuf v1.36.6
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||
)
|
||||
@@ -43,12 +45,11 @@ require (
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
golang.org/x/mod v0.24.0 // indirect
|
||||
golang.org/x/mod v0.23.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
golang.org/x/tools v0.30.0 // indirect
|
||||
)
|
||||
|
||||
69
go.sum
69
go.sum
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
|
||||
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
|
||||
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@@ -64,8 +64,8 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
|
||||
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
|
||||
github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
|
||||
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
|
||||
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
|
||||
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
|
||||
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
|
||||
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -106,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
|
||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
|
||||
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
||||
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
@@ -143,33 +143,29 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
|
||||
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -180,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -189,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -201,17 +197,18 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
|
||||
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -222,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -242,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
115
handshake_ix.go
115
handshake_ix.go
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// NOISE IX Handshakes
|
||||
@@ -23,17 +24,13 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.initiatingVersion
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
} else if v < cert.Version2 {
|
||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +49,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return false
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
@@ -74,6 +70,15 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
},
|
||||
}
|
||||
|
||||
if f.multiPort.Tx || f.multiPort.Rx {
|
||||
hs.Details.InitiatorMultiPort = &MultiPortDetails{
|
||||
RxSupported: f.multiPort.Rx,
|
||||
TxSupported: f.multiPort.Tx,
|
||||
BasePort: uint32(f.multiPort.TxBasePort),
|
||||
TotalPorts: uint32(f.multiPort.TxPorts),
|
||||
}
|
||||
}
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
@@ -108,7 +113,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.initiatingVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
return
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
@@ -149,8 +153,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, fperr := rc.Fingerprint()
|
||||
if fperr != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
@@ -169,19 +173,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if myCertOtherVersion == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithError(err).WithFields(m{
|
||||
"udpAddr": addr,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert": remoteCert,
|
||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||
}
|
||||
} else {
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = myCertOtherVersion
|
||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if rc == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Unable to handshake with host due to missing certificate version")
|
||||
return
|
||||
}
|
||||
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = rc
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
@@ -250,6 +251,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
return
|
||||
}
|
||||
|
||||
var multiportTx, multiportRx bool
|
||||
if f.multiPort.Rx || f.multiPort.Tx {
|
||||
if hs.Details.InitiatorMultiPort != nil {
|
||||
multiportTx = hs.Details.InitiatorMultiPort.RxSupported && f.multiPort.Tx
|
||||
multiportRx = hs.Details.InitiatorMultiPort.TxSupported && f.multiPort.Rx
|
||||
}
|
||||
|
||||
hs.Details.ResponderMultiPort = &MultiPortDetails{
|
||||
TxSupported: f.multiPort.Tx,
|
||||
RxSupported: f.multiPort.Rx,
|
||||
BasePort: uint32(f.multiPort.TxBasePort),
|
||||
TotalPorts: uint32(f.multiPort.TxPorts),
|
||||
}
|
||||
}
|
||||
if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(addr.Port()) {
|
||||
// The other side sent us a handshake from a different port, make sure
|
||||
// we send responses back to the BasePort
|
||||
addr = netip.AddrPortFrom(addr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort))
|
||||
}
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
ConnectionState: ci,
|
||||
localIndexId: myIndex,
|
||||
@@ -257,6 +278,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
vpnAddrs: vpnAddrs,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
lastHandshakeTime: hs.Details.Time,
|
||||
multiportTx: multiportTx,
|
||||
multiportRx: multiportRx,
|
||||
relayState: RelayState{
|
||||
relays: nil,
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
@@ -271,6 +294,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("multiportTx", multiportTx).WithField("multiportRx", multiportRx).
|
||||
Info("Handshake message received")
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
@@ -347,6 +371,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrAlreadySeen:
|
||||
if hostinfo.multiportRx {
|
||||
// The other host is sending to us with multiport, so only grab the IP
|
||||
addr = netip.AddrPortFrom(addr.Addr(), hostinfo.remote.Port())
|
||||
}
|
||||
// Update remote if preferred
|
||||
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
|
||||
// Send a test packet to ensure the other side has also switched to
|
||||
@@ -357,7 +385,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
msg = existing.HandshakePacket[2]
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if addr.IsValid() {
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if multiportTx {
|
||||
// TODO remove alloc here
|
||||
raw := make([]byte, len(msg)+udp.RawOverhead)
|
||||
copy(raw[udp.RawOverhead:], msg)
|
||||
err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), addr)
|
||||
} else {
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
}
|
||||
if err != nil {
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
@@ -426,7 +461,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
// Do the send
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if addr.IsValid() {
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
if multiportTx {
|
||||
// TODO remove alloc here
|
||||
raw := make([]byte, len(msg)+udp.RawOverhead)
|
||||
copy(raw[udp.RawOverhead:], msg)
|
||||
err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), addr)
|
||||
} else {
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
}
|
||||
if err != nil {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
@@ -468,7 +510,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
hostinfo.remotes.ResetBlockedRemotes()
|
||||
|
||||
return
|
||||
}
|
||||
@@ -522,6 +564,20 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
return true
|
||||
}
|
||||
|
||||
if (f.multiPort.Tx || f.multiPort.Rx) && hs.Details.ResponderMultiPort != nil {
|
||||
hostinfo.multiportTx = hs.Details.ResponderMultiPort.RxSupported && f.multiPort.Tx
|
||||
hostinfo.multiportRx = hs.Details.ResponderMultiPort.TxSupported && f.multiPort.Rx
|
||||
}
|
||||
|
||||
if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(addr.Port()) {
|
||||
// The other side sent us a handshake from a different port, make sure
|
||||
// we send responses back to the BasePort
|
||||
addr = netip.AddrPortFrom(
|
||||
addr.Addr(),
|
||||
uint16(hs.Details.ResponderMultiPort.BasePort),
|
||||
)
|
||||
}
|
||||
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
@@ -653,6 +709,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||
WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx).
|
||||
Info("Handshake message received")
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
@@ -676,7 +733,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||
}
|
||||
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
hostinfo.remotes.ResetBlockedRemotes()
|
||||
f.metricHandshakes.Update(duration)
|
||||
|
||||
return false
|
||||
|
||||
@@ -61,6 +61,9 @@ type HandshakeManager struct {
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
|
||||
multiPort MultiPortConfig
|
||||
udpRaw *udp.RawConn
|
||||
|
||||
// can be used to trigger outbound handshake for the given vpnIp
|
||||
trigger chan netip.Addr
|
||||
}
|
||||
@@ -68,12 +71,11 @@ type HandshakeManager struct {
|
||||
type HandshakeHostInfo struct {
|
||||
sync.Mutex
|
||||
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
}
|
||||
@@ -237,6 +239,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
|
||||
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
|
||||
var sentTo []netip.AddrPort
|
||||
var sentMultiport bool
|
||||
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
|
||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||
@@ -249,6 +252,27 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
} else {
|
||||
sentTo = append(sentTo, addr)
|
||||
}
|
||||
|
||||
// Attempt a multiport handshake if we are past the TxHandshakeDelay attempts
|
||||
if hm.multiPort.TxHandshake && hm.udpRaw != nil && hh.counter >= hm.multiPort.TxHandshakeDelay {
|
||||
sentMultiport = true
|
||||
// We need to re-allocate with 8 bytes at the start of SOCK_RAW
|
||||
raw := hostinfo.HandshakePacket[0x80]
|
||||
if raw == nil {
|
||||
raw = make([]byte, len(hostinfo.HandshakePacket[0])+udp.RawOverhead)
|
||||
copy(raw[udp.RawOverhead:], hostinfo.HandshakePacket[0])
|
||||
hostinfo.HandshakePacket[0x80] = raw
|
||||
}
|
||||
|
||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err = hm.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(hm.multiPort.TxPorts), addr)
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithField("udpAddr", addr).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
|
||||
@@ -257,6 +281,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("multiportHandshake", sentMultiport).
|
||||
Info("Handshake message sent")
|
||||
} else if hm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
||||
|
||||
21
hostmap.go
21
hostmap.go
@@ -17,10 +17,12 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// const ProbeLen = 100
|
||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
|
||||
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
|
||||
const MaxRemotes = 10
|
||||
const maxRecvError = 4
|
||||
|
||||
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
|
||||
// 5 allows for an initial handshake and each host pair re-handshaking twice
|
||||
@@ -223,12 +225,19 @@ type HostInfo struct {
|
||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
||||
// The host may have other vpn addresses that are outside our
|
||||
// vpn networks but were removed because they are not usable
|
||||
vpnAddrs []netip.Addr
|
||||
vpnAddrs []netip.Addr
|
||||
recvError atomic.Uint32
|
||||
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Lite
|
||||
relayState RelayState
|
||||
|
||||
// If true, we should send to this remote using multiport
|
||||
multiportTx bool
|
||||
|
||||
// If true, we will receive from this remote using multiport
|
||||
multiportRx bool
|
||||
|
||||
// HandshakePacket records the packets used to create this hostinfo
|
||||
// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
|
||||
HandshakePacket map[uint8][]byte
|
||||
@@ -730,6 +739,13 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *HostInfo) RecvErrorExceeded() bool {
|
||||
if i.recvError.Add(1) >= maxRecvError {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
@@ -738,8 +754,7 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
|
||||
i.networks = new(bart.Lite)
|
||||
for _, network := range networks {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
i.networks.Insert(nprefix)
|
||||
i.networks.Insert(network)
|
||||
}
|
||||
|
||||
for _, network := range unsafeNetworks {
|
||||
|
||||
144
inside.go
144
inside.go
@@ -2,18 +2,17 @@ package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) {
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
@@ -70,7 +69,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q, fwPacket)
|
||||
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
@@ -119,7 +118,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
||||
return
|
||||
}
|
||||
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil)
|
||||
}
|
||||
|
||||
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||
@@ -230,7 +229,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
return
|
||||
}
|
||||
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
|
||||
}
|
||||
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||
@@ -260,12 +259,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
|
||||
|
||||
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
|
||||
f.messageMetrics.Tx(t, st, 1)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
|
||||
}
|
||||
|
||||
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
|
||||
f.messageMetrics.Tx(t, st, 1)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0, nil)
|
||||
}
|
||||
|
||||
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
|
||||
@@ -333,25 +332,30 @@ func (f *Interface) SendVia(via *HostInfo,
|
||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||
}
|
||||
|
||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
|
||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) {
|
||||
if ci.eKey == nil {
|
||||
return
|
||||
}
|
||||
target := remote
|
||||
if !target.IsValid() {
|
||||
target = hostinfo.remote
|
||||
}
|
||||
useRelay := !target.IsValid()
|
||||
fullOut := out
|
||||
|
||||
var pkt *overlay.Packet
|
||||
if !useRelay && f.batches.Enabled() {
|
||||
pkt = f.batches.newPacket()
|
||||
if pkt != nil {
|
||||
out = pkt.Payload()[:0]
|
||||
multiport := f.multiPort.Tx && hostinfo.multiportTx
|
||||
rawOut := out
|
||||
if multiport {
|
||||
if len(out) < udp.RawOverhead {
|
||||
// NOTE: This is because some spots in the code send us `out[:0]`, so
|
||||
// we need to expand the slice back out to get our 8 bytes back.
|
||||
out = out[:udp.RawOverhead]
|
||||
}
|
||||
// Preserve bytes needed for the raw socket
|
||||
out = out[udp.RawOverhead:]
|
||||
|
||||
if udpPortGetter == nil {
|
||||
udpPortGetter = udp.RandomSendPort
|
||||
}
|
||||
}
|
||||
|
||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||
fullOut := out
|
||||
|
||||
if useRelay {
|
||||
if len(out) < header.Len {
|
||||
// out always has a capacity of mtu, but not always a length greater than the header.Len.
|
||||
@@ -385,85 +389,53 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(p) > 0 && slicesOverlap(out, p) {
|
||||
tmp := make([]byte, len(p))
|
||||
copy(tmp, p)
|
||||
p = tmp
|
||||
}
|
||||
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", target).WithField("counter", c).
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
return
|
||||
}
|
||||
|
||||
if target.IsValid() {
|
||||
if pkt != nil {
|
||||
pkt.Len = len(out)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"queue": q,
|
||||
"dest": target,
|
||||
"payload_len": pkt.Len,
|
||||
"use_batches": true,
|
||||
"remote_index": hostinfo.remoteIndexId,
|
||||
}).Debug("enqueueing packet to UDP batch queue")
|
||||
}
|
||||
if f.tryQueuePacket(q, pkt, target) {
|
||||
return
|
||||
}
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"queue": q,
|
||||
"dest": target,
|
||||
}).Debug("failed to enqueue packet; falling back to immediate send")
|
||||
}
|
||||
f.writeImmediatePacket(q, pkt, target, hostinfo)
|
||||
return
|
||||
if remote.IsValid() {
|
||||
if multiport {
|
||||
rawOut = rawOut[:len(out)+udp.RawOverhead]
|
||||
port := udpPortGetter.UDPSendPort(f.multiPort.TxPorts)
|
||||
err = f.udpRaw.WriteTo(rawOut, port, remote)
|
||||
} else {
|
||||
err = f.writers[q].WriteTo(out, remote)
|
||||
}
|
||||
if f.tryQueueDatagram(q, out, target) {
|
||||
return
|
||||
}
|
||||
f.writeImmediate(q, out, target, hostinfo)
|
||||
return
|
||||
}
|
||||
|
||||
// fall back to relay path
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
|
||||
// Try to send via a relay
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||
continue
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
if multiport {
|
||||
rawOut = rawOut[:len(out)+udp.RawOverhead]
|
||||
port := udpPortGetter.UDPSendPort(f.multiPort.TxPorts)
|
||||
err = f.udpRaw.WriteTo(rawOut, port, hostinfo.remote)
|
||||
} else {
|
||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else {
|
||||
// Try to send via a relay
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||
continue
|
||||
}
|
||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||
break
|
||||
}
|
||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// slicesOverlap reports whether the two byte slices share any portion of memory.
|
||||
// cipher.AEAD.Seal requires plaintext and dst to live in disjoint regions.
|
||||
func slicesOverlap(a, b []byte) bool {
|
||||
if len(a) == 0 || len(b) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
aStart := uintptr(unsafe.Pointer(&a[0]))
|
||||
aEnd := aStart + uintptr(len(a))
|
||||
bStart := uintptr(unsafe.Pointer(&b[0]))
|
||||
bEnd := bStart + uintptr(len(b))
|
||||
return aStart < bEnd && bStart < aEnd
|
||||
}
|
||||
|
||||
719
interface.go
719
interface.go
@@ -8,7 +8,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -22,13 +21,7 @@ import (
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
mtu = 9001
|
||||
defaultGSOFlushInterval = 150 * time.Microsecond
|
||||
defaultBatchQueueDepthFactor = 4
|
||||
defaultGSOMaxSegments = 8
|
||||
maxKernelGSOSegments = 64
|
||||
)
|
||||
const mtu = 9001
|
||||
|
||||
type InterfaceConfig struct {
|
||||
HostMap *HostMap
|
||||
@@ -43,9 +36,6 @@ type InterfaceConfig struct {
|
||||
connectionManager *connectionManager
|
||||
DropLocalBroadcast bool
|
||||
DropMulticast bool
|
||||
EnableGSO bool
|
||||
EnableGRO bool
|
||||
GSOMaxSegments int
|
||||
routines int
|
||||
MessageMetrics *MessageMetrics
|
||||
version string
|
||||
@@ -57,8 +47,6 @@ type InterfaceConfig struct {
|
||||
reQueryWait time.Duration
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
BatchFlushInterval time.Duration
|
||||
BatchQueueDepth int
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
@@ -96,20 +84,12 @@ type Interface struct {
|
||||
version string
|
||||
|
||||
conntrackCacheTimeout time.Duration
|
||||
batchQueueDepth int
|
||||
enableGSO bool
|
||||
enableGRO bool
|
||||
gsoMaxSegments int
|
||||
batchUDPQueueGauge metrics.Gauge
|
||||
batchUDPFlushCounter metrics.Counter
|
||||
batchTunQueueGauge metrics.Gauge
|
||||
batchTunFlushCounter metrics.Counter
|
||||
batchFlushInterval atomic.Int64
|
||||
sendSem chan struct{}
|
||||
|
||||
writers []udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
batches batchPipelines
|
||||
udpRaw *udp.RawConn
|
||||
|
||||
multiPort MultiPortConfig
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
@@ -118,6 +98,15 @@ type Interface struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type MultiPortConfig struct {
|
||||
Tx bool
|
||||
Rx bool
|
||||
TxBasePort uint16
|
||||
TxPorts int
|
||||
TxHandshake bool
|
||||
TxHandshakeDelay int64
|
||||
}
|
||||
|
||||
type EncWriter interface {
|
||||
SendVia(via *HostInfo,
|
||||
relay *Relay,
|
||||
@@ -184,22 +173,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
return nil, errors.New("no connection manager")
|
||||
}
|
||||
|
||||
if c.GSOMaxSegments <= 0 {
|
||||
c.GSOMaxSegments = defaultGSOMaxSegments
|
||||
}
|
||||
if c.GSOMaxSegments > maxKernelGSOSegments {
|
||||
c.GSOMaxSegments = maxKernelGSOSegments
|
||||
}
|
||||
if c.BatchQueueDepth <= 0 {
|
||||
c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
|
||||
}
|
||||
if c.BatchFlushInterval < 0 {
|
||||
c.BatchFlushInterval = 0
|
||||
}
|
||||
if c.BatchFlushInterval == 0 && c.EnableGSO {
|
||||
c.BatchFlushInterval = defaultGSOFlushInterval
|
||||
}
|
||||
|
||||
cs := c.pki.getCertState()
|
||||
ifce := &Interface{
|
||||
pki: c.pki,
|
||||
@@ -225,10 +198,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
relayManager: c.relayManager,
|
||||
connectionManager: c.connectionManager,
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
batchQueueDepth: c.BatchQueueDepth,
|
||||
enableGSO: c.EnableGSO,
|
||||
enableGRO: c.EnableGRO,
|
||||
gsoMaxSegments: c.GSOMaxSegments,
|
||||
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
messageMetrics: c.MessageMetrics,
|
||||
@@ -241,25 +210,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
}
|
||||
|
||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||
ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
|
||||
ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
|
||||
ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
|
||||
ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
|
||||
ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
|
||||
ifce.sendSem = make(chan struct{}, c.routines)
|
||||
ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
|
||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||
if c.l.Level >= logrus.DebugLevel {
|
||||
c.l.WithFields(logrus.Fields{
|
||||
"enableGSO": c.EnableGSO,
|
||||
"enableGRO": c.EnableGRO,
|
||||
"gsoMaxSegments": c.GSOMaxSegments,
|
||||
"batchQueueDepth": c.BatchQueueDepth,
|
||||
"batchFlush": c.BatchFlushInterval,
|
||||
"batching": ifce.batches.Enabled(),
|
||||
}).Debug("initialized batch pipelines")
|
||||
}
|
||||
|
||||
ifce.connectionManager.intf = ifce
|
||||
|
||||
@@ -284,6 +236,8 @@ func (f *Interface) activate() {
|
||||
|
||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||
|
||||
metrics.GetOrRegisterGauge("multiport.tx_ports", nil).Update(int64(f.multiPort.TxPorts))
|
||||
|
||||
// Prepare n tun queues
|
||||
var reader io.ReadWriteCloser = f.inside
|
||||
for i := 0; i < f.routines; i++ {
|
||||
@@ -308,18 +262,6 @@ func (f *Interface) run() {
|
||||
go f.listenOut(i)
|
||||
}
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
|
||||
}
|
||||
|
||||
if f.batches.Enabled() {
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.runInsideBatchWorker(i)
|
||||
go f.runTunWriteQueue(i)
|
||||
go f.runSendQueue(i)
|
||||
}
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from tun dev
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.listenIn(f.readers[i], i)
|
||||
@@ -351,17 +293,6 @@ func (f *Interface) listenOut(i int) {
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
if f.batches.Enabled() {
|
||||
if br, ok := reader.(overlay.BatchReader); ok {
|
||||
f.listenInBatchLocked(reader, br, i)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
f.listenInLegacyLocked(reader, i)
|
||||
}
|
||||
|
||||
func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
@@ -385,581 +316,6 @@ func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
|
||||
pool := f.batches.Pool()
|
||||
if pool == nil {
|
||||
f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
|
||||
f.listenInLegacyLocked(raw, i)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
packets, err := reader.ReadIntoBatch(pool)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
if isVirtioHeadroomError(err) {
|
||||
f.l.WithError(err).Warn("Batch reader fell back due to tun headroom issue")
|
||||
f.listenInLegacyLocked(raw, i)
|
||||
return
|
||||
}
|
||||
|
||||
f.l.WithError(err).Error("Error while reading outbound packet batch")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
if len(packets) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pkt := range packets {
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
if !f.batches.enqueueRx(i, pkt) {
|
||||
pkt.Release()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) runInsideBatchWorker(i int) {
|
||||
queue := f.batches.rxQueue(i)
|
||||
if queue == nil {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for pkt := range queue {
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||
pkt.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) runSendQueue(i int) {
|
||||
queue := f.batches.txQueue(i)
|
||||
if queue == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
|
||||
}
|
||||
return
|
||||
}
|
||||
writer := f.writerForIndex(i)
|
||||
if writer == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
|
||||
}
|
||||
return
|
||||
}
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("queue", i).Debug("send queue worker started")
|
||||
}
|
||||
defer func() {
|
||||
if f.l.Level >= logrus.WarnLevel {
|
||||
f.l.WithField("queue", i).Warn("send queue worker exited")
|
||||
}
|
||||
}()
|
||||
|
||||
batchCap := f.batches.batchSizeHint()
|
||||
if batchCap <= 0 {
|
||||
batchCap = 1
|
||||
}
|
||||
gsoLimit := f.effectiveGSOMaxSegments()
|
||||
if gsoLimit > batchCap {
|
||||
batchCap = gsoLimit
|
||||
}
|
||||
pending := make([]queuedDatagram, 0, batchCap)
|
||||
var (
|
||||
flushTimer *time.Timer
|
||||
flushC <-chan time.Time
|
||||
)
|
||||
dispatch := func(reason string, timerFired bool) {
|
||||
if len(pending) == 0 {
|
||||
return
|
||||
}
|
||||
batch := pending
|
||||
f.flushAndReleaseBatch(i, writer, batch, reason)
|
||||
for idx := range batch {
|
||||
batch[idx] = queuedDatagram{}
|
||||
}
|
||||
pending = pending[:0]
|
||||
if flushTimer != nil {
|
||||
if !timerFired {
|
||||
if !flushTimer.Stop() {
|
||||
select {
|
||||
case <-flushTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
flushTimer = nil
|
||||
flushC = nil
|
||||
}
|
||||
}
|
||||
armTimer := func() {
|
||||
delay := f.currentBatchFlushInterval()
|
||||
if delay <= 0 {
|
||||
dispatch("nogso", false)
|
||||
return
|
||||
}
|
||||
if flushTimer == nil {
|
||||
flushTimer = time.NewTimer(delay)
|
||||
flushC = flushTimer.C
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case d := <-queue:
|
||||
if d.packet == nil {
|
||||
continue
|
||||
}
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"queue": i,
|
||||
"payload_len": d.packet.Len,
|
||||
"dest": d.addr,
|
||||
}).Debug("send queue received packet")
|
||||
}
|
||||
pending = append(pending, d)
|
||||
if gsoLimit > 0 && len(pending) >= gsoLimit {
|
||||
dispatch("gso", false)
|
||||
continue
|
||||
}
|
||||
if len(pending) >= cap(pending) {
|
||||
dispatch("cap", false)
|
||||
continue
|
||||
}
|
||||
armTimer()
|
||||
f.observeUDPQueueLen(i)
|
||||
case <-flushC:
|
||||
dispatch("timer", true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) runTunWriteQueue(i int) {
|
||||
queue := f.batches.tunQueue(i)
|
||||
if queue == nil {
|
||||
return
|
||||
}
|
||||
writer := f.batches.inside
|
||||
if writer == nil {
|
||||
return
|
||||
}
|
||||
requiredHeadroom := writer.BatchHeadroom()
|
||||
|
||||
batchCap := f.batches.batchSizeHint()
|
||||
if batchCap <= 0 {
|
||||
batchCap = 1
|
||||
}
|
||||
pending := make([]*overlay.Packet, 0, batchCap)
|
||||
var (
|
||||
flushTimer *time.Timer
|
||||
flushC <-chan time.Time
|
||||
)
|
||||
flush := func(reason string, timerFired bool) {
|
||||
if len(pending) == 0 {
|
||||
return
|
||||
}
|
||||
valid := pending[:0]
|
||||
for idx := range pending {
|
||||
if !f.ensurePacketHeadroom(&pending[idx], requiredHeadroom, i, reason) {
|
||||
pending[idx] = nil
|
||||
continue
|
||||
}
|
||||
if pending[idx] != nil {
|
||||
valid = append(valid, pending[idx])
|
||||
}
|
||||
}
|
||||
if len(valid) > 0 {
|
||||
if _, err := writer.WriteBatch(valid); err != nil {
|
||||
f.l.WithError(err).
|
||||
WithField("queue", i).
|
||||
WithField("reason", reason).
|
||||
Warn("Failed to write tun batch")
|
||||
for _, pkt := range valid {
|
||||
if pkt != nil {
|
||||
f.writePacketToTun(i, pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
pending = pending[:0]
|
||||
if flushTimer != nil {
|
||||
if !timerFired {
|
||||
if !flushTimer.Stop() {
|
||||
select {
|
||||
case <-flushTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
flushTimer = nil
|
||||
flushC = nil
|
||||
}
|
||||
}
|
||||
armTimer := func() {
|
||||
delay := f.currentBatchFlushInterval()
|
||||
if delay <= 0 {
|
||||
return
|
||||
}
|
||||
if flushTimer == nil {
|
||||
flushTimer = time.NewTimer(delay)
|
||||
flushC = flushTimer.C
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case pkt := <-queue:
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
if f.ensurePacketHeadroom(&pkt, requiredHeadroom, i, "queue") {
|
||||
pending = append(pending, pkt)
|
||||
}
|
||||
if len(pending) >= cap(pending) {
|
||||
flush("cap", false)
|
||||
continue
|
||||
}
|
||||
armTimer()
|
||||
f.observeTunQueueLen(i)
|
||||
case <-flushC:
|
||||
flush("timer", true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
f.flushDatagrams(index, writer, batch, reason)
|
||||
for idx := range batch {
|
||||
if batch[idx].packet != nil {
|
||||
batch[idx].packet.Release()
|
||||
batch[idx].packet = nil
|
||||
}
|
||||
}
|
||||
if f.batchUDPFlushCounter != nil {
|
||||
f.batchUDPFlushCounter.Inc(int64(len(batch)))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"writer": index,
|
||||
"reason": reason,
|
||||
"pending": len(batch),
|
||||
}).Debug("udp batch flush summary")
|
||||
}
|
||||
maxSeg := f.effectiveGSOMaxSegments()
|
||||
if bw, ok := writer.(udp.BatchConn); ok {
|
||||
chunkCap := maxSeg
|
||||
if chunkCap <= 0 {
|
||||
chunkCap = len(batch)
|
||||
}
|
||||
chunk := make([]udp.Datagram, 0, chunkCap)
|
||||
var (
|
||||
currentAddr netip.AddrPort
|
||||
segments int
|
||||
)
|
||||
flushChunk := func() {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"writer": index,
|
||||
"segments": len(chunk),
|
||||
"dest": chunk[0].Addr,
|
||||
"reason": reason,
|
||||
"pending_total": len(batch),
|
||||
}).Debug("flushing UDP batch")
|
||||
}
|
||||
if err := bw.WriteBatch(chunk); err != nil {
|
||||
f.l.WithError(err).
|
||||
WithField("writer", index).
|
||||
WithField("reason", reason).
|
||||
Warn("Failed to write UDP batch")
|
||||
}
|
||||
chunk = chunk[:0]
|
||||
segments = 0
|
||||
}
|
||||
for _, item := range batch {
|
||||
if item.packet == nil || !item.addr.IsValid() {
|
||||
continue
|
||||
}
|
||||
payload := item.packet.Payload()[:item.packet.Len]
|
||||
if segments == 0 {
|
||||
currentAddr = item.addr
|
||||
}
|
||||
if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) {
|
||||
flushChunk()
|
||||
currentAddr = item.addr
|
||||
}
|
||||
chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr})
|
||||
segments++
|
||||
}
|
||||
flushChunk()
|
||||
return
|
||||
}
|
||||
for _, item := range batch {
|
||||
if item.packet == nil || !item.addr.IsValid() {
|
||||
continue
|
||||
}
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"writer": index,
|
||||
"reason": reason,
|
||||
"dest": item.addr,
|
||||
"segments": 1,
|
||||
}).Debug("flushing UDP batch")
|
||||
}
|
||||
if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil {
|
||||
f.l.WithError(err).
|
||||
WithField("writer", index).
|
||||
WithField("udpAddr", item.addr).
|
||||
WithField("reason", reason).
|
||||
Warn("Failed to write UDP packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool {
|
||||
if !addr.IsValid() || !f.batches.Enabled() {
|
||||
return false
|
||||
}
|
||||
pkt := f.batches.newPacket()
|
||||
if pkt == nil {
|
||||
return false
|
||||
}
|
||||
payload := pkt.Payload()
|
||||
if len(payload) < len(buf) {
|
||||
pkt.Release()
|
||||
return false
|
||||
}
|
||||
copy(payload, buf)
|
||||
pkt.Len = len(buf)
|
||||
if f.batches.enqueueTx(q, pkt, addr) {
|
||||
f.observeUDPQueueLen(q)
|
||||
return true
|
||||
}
|
||||
pkt.Release()
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *Interface) writerForIndex(i int) udp.Conn {
|
||||
if i < 0 || i >= len(f.writers) {
|
||||
return nil
|
||||
}
|
||||
return f.writers[i]
|
||||
}
|
||||
|
||||
func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) {
|
||||
writer := f.writerForIndex(q)
|
||||
if writer == nil {
|
||||
f.l.WithField("udpAddr", addr).
|
||||
WithField("writer", q).
|
||||
Error("Failed to write outgoing packet: no writer available")
|
||||
return
|
||||
}
|
||||
if err := writer.WriteTo(buf, addr); err != nil {
|
||||
hostinfo.logger(f.l).
|
||||
WithError(err).
|
||||
WithField("udpAddr", addr).
|
||||
Error("Failed to write outgoing packet")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool {
|
||||
if pkt == nil || !addr.IsValid() || !f.batches.Enabled() {
|
||||
return false
|
||||
}
|
||||
if f.batches.enqueueTx(q, pkt, addr) {
|
||||
f.observeUDPQueueLen(q)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) {
|
||||
if pkt == nil {
|
||||
return
|
||||
}
|
||||
writer := f.writerForIndex(q)
|
||||
if writer == nil {
|
||||
f.l.WithField("udpAddr", addr).
|
||||
WithField("writer", q).
|
||||
Error("Failed to write outgoing packet: no writer available")
|
||||
pkt.Release()
|
||||
return
|
||||
}
|
||||
if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil {
|
||||
hostinfo.logger(f.l).
|
||||
WithError(err).
|
||||
WithField("udpAddr", addr).
|
||||
Error("Failed to write outgoing packet")
|
||||
}
|
||||
pkt.Release()
|
||||
}
|
||||
|
||||
func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) {
|
||||
if pkt == nil {
|
||||
return
|
||||
}
|
||||
writer := f.readers[q]
|
||||
if writer == nil {
|
||||
pkt.Release()
|
||||
return
|
||||
}
|
||||
if bw, ok := writer.(interface {
|
||||
WriteBatch([]*overlay.Packet) (int, error)
|
||||
}); ok {
|
||||
if _, err := bw.WriteBatch([]*overlay.Packet{pkt}); err != nil {
|
||||
f.l.WithError(err).WithField("queue", q).Warn("Failed to write tun packet via batch writer")
|
||||
pkt.Release()
|
||||
}
|
||||
return
|
||||
}
|
||||
if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil {
|
||||
f.l.WithError(err).Error("Failed to write to tun")
|
||||
}
|
||||
pkt.Release()
|
||||
}
|
||||
|
||||
func (f *Interface) clonePacketWithHeadroom(pkt *overlay.Packet, required int) *overlay.Packet {
|
||||
if pkt == nil {
|
||||
return nil
|
||||
}
|
||||
payload := pkt.Payload()[:pkt.Len]
|
||||
if len(payload) == 0 && required <= 0 {
|
||||
return pkt
|
||||
}
|
||||
|
||||
pool := f.batches.Pool()
|
||||
if pool != nil {
|
||||
if clone := pool.Get(); clone != nil {
|
||||
if len(clone.Payload()) >= len(payload) {
|
||||
clone.Len = copy(clone.Payload(), payload)
|
||||
pkt.Release()
|
||||
return clone
|
||||
}
|
||||
clone.Release()
|
||||
}
|
||||
}
|
||||
|
||||
if required < 0 {
|
||||
required = 0
|
||||
}
|
||||
buf := make([]byte, required+len(payload))
|
||||
n := copy(buf[required:], payload)
|
||||
pkt.Release()
|
||||
return &overlay.Packet{
|
||||
Buf: buf,
|
||||
Offset: required,
|
||||
Len: n,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) observeUDPQueueLen(i int) {
|
||||
if f.batchUDPQueueGauge == nil {
|
||||
return
|
||||
}
|
||||
f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i)))
|
||||
}
|
||||
|
||||
func (f *Interface) observeTunQueueLen(i int) {
|
||||
if f.batchTunQueueGauge == nil {
|
||||
return
|
||||
}
|
||||
f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i)))
|
||||
}
|
||||
|
||||
func (f *Interface) currentBatchFlushInterval() time.Duration {
|
||||
if v := f.batchFlushInterval.Load(); v > 0 {
|
||||
return time.Duration(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (f *Interface) ensurePacketHeadroom(pkt **overlay.Packet, required int, queue int, reason string) bool {
|
||||
p := *pkt
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
if required <= 0 || p.Offset >= required {
|
||||
return true
|
||||
}
|
||||
clone := f.clonePacketWithHeadroom(p, required)
|
||||
if clone == nil {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"queue": queue,
|
||||
"reason": reason,
|
||||
}).Warn("dropping packet lacking tun headroom")
|
||||
return false
|
||||
}
|
||||
*pkt = clone
|
||||
return true
|
||||
}
|
||||
|
||||
func isVirtioHeadroomError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "headroom") || strings.Contains(msg, "virtio")
|
||||
}
|
||||
|
||||
func (f *Interface) effectiveGSOMaxSegments() int {
|
||||
max := f.gsoMaxSegments
|
||||
if max <= 0 {
|
||||
max = defaultGSOMaxSegments
|
||||
}
|
||||
if max > maxKernelGSOSegments {
|
||||
max = maxKernelGSOSegments
|
||||
}
|
||||
if !f.enableGSO {
|
||||
return 1
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
type udpOffloadConfigurator interface {
|
||||
ConfigureOffload(enableGSO, enableGRO bool, maxSegments int)
|
||||
}
|
||||
|
||||
func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) {
|
||||
if maxSegments <= 0 {
|
||||
maxSegments = defaultGSOMaxSegments
|
||||
}
|
||||
if maxSegments > maxKernelGSOSegments {
|
||||
maxSegments = maxKernelGSOSegments
|
||||
}
|
||||
f.enableGSO = enableGSO
|
||||
f.enableGRO = enableGRO
|
||||
f.gsoMaxSegments = maxSegments
|
||||
for _, writer := range f.writers {
|
||||
if cfg, ok := writer.(udpOffloadConfigurator); ok {
|
||||
cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
c.RegisterReloadCallback(f.reloadFirewall)
|
||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||
@@ -1062,42 +418,6 @@ func (f *Interface) reloadMisc(c *config.C) {
|
||||
f.reQueryWait.Store(int64(n))
|
||||
f.l.Info("timers.requery_wait_duration has changed")
|
||||
}
|
||||
|
||||
if c.HasChanged("listen.gso_flush_timeout") {
|
||||
d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
|
||||
if d < 0 {
|
||||
d = 0
|
||||
}
|
||||
f.batchFlushInterval.Store(int64(d))
|
||||
f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed")
|
||||
} else if c.HasChanged("batch.flush_interval") {
|
||||
d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval)
|
||||
if d < 0 {
|
||||
d = 0
|
||||
}
|
||||
f.batchFlushInterval.Store(int64(d))
|
||||
f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout")
|
||||
}
|
||||
|
||||
if c.HasChanged("batch.queue_depth") {
|
||||
n := c.GetInt("batch.queue_depth", f.batchQueueDepth)
|
||||
if n != f.batchQueueDepth {
|
||||
f.batchQueueDepth = n
|
||||
f.l.Warn("batch.queue_depth changes require a restart to take effect")
|
||||
}
|
||||
}
|
||||
|
||||
if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") {
|
||||
enableGSO := c.GetBool("listen.enable_gso", f.enableGSO)
|
||||
enableGRO := c.GetBool("listen.enable_gro", f.enableGRO)
|
||||
maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments)
|
||||
f.applyOffloadConfig(enableGSO, enableGRO, maxSeg)
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"enableGSO": enableGSO,
|
||||
"enableGRO": enableGRO,
|
||||
"gsoMaxSegments": maxSeg,
|
||||
}).Info("listen GSO/GRO configuration updated")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
@@ -1106,6 +426,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
|
||||
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
||||
|
||||
var rawStats func()
|
||||
|
||||
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
||||
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
|
||||
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
||||
@@ -1124,6 +446,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
||||
certInitiatingVersion.Update(int64(defaultCrt.Version()))
|
||||
|
||||
if f.udpRaw != nil {
|
||||
if rawStats == nil {
|
||||
rawStats = udp.NewRawStatsEmitter(f.udpRaw)
|
||||
}
|
||||
rawStats()
|
||||
}
|
||||
|
||||
// Report the max certificate version we are capable of using
|
||||
if certState.v2Cert != nil {
|
||||
certMaxVersion.Update(int64(certState.v2Cert.Version()))
|
||||
|
||||
237
lighthouse.go
237
lighthouse.go
@@ -24,7 +24,6 @@ import (
|
||||
)
|
||||
|
||||
var ErrHostNotKnown = errors.New("host not known")
|
||||
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
|
||||
|
||||
type LightHouse struct {
|
||||
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
||||
@@ -57,7 +56,7 @@ type LightHouse struct {
|
||||
// staticList exists to avoid having a bool in each addrMap entry
|
||||
// since static should be rare
|
||||
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
||||
lighthouses atomic.Pointer[[]netip.Addr]
|
||||
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
|
||||
|
||||
interval atomic.Int64
|
||||
updateCancel context.CancelFunc
|
||||
@@ -108,7 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||
l: l,
|
||||
}
|
||||
lighthouses := make([]netip.Addr, 0)
|
||||
lighthouses := make(map[netip.Addr]struct{})
|
||||
h.lighthouses.Store(&lighthouses)
|
||||
staticList := make(map[netip.Addr]struct{})
|
||||
h.staticList.Store(&staticList)
|
||||
@@ -144,7 +143,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
|
||||
return *lh.staticList.Load()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) GetLighthouses() []netip.Addr {
|
||||
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
|
||||
return *lh.lighthouses.Load()
|
||||
}
|
||||
|
||||
@@ -307,12 +306,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("lighthouse.hosts") {
|
||||
lhList, err := lh.parseLighthouses(c)
|
||||
lhMap := make(map[netip.Addr]struct{})
|
||||
err := lh.parseLighthouses(c, lhMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lh.lighthouses.Store(&lhList)
|
||||
lh.lighthouses.Store(&lhMap)
|
||||
if !initial {
|
||||
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
||||
lh.l.Info("lighthouse.hosts has changed")
|
||||
@@ -346,37 +346,36 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
|
||||
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||
if lh.amLighthouse && len(lhs) != 0 {
|
||||
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||
}
|
||||
out := make([]netip.Addr, len(lhs))
|
||||
|
||||
for i, host := range lhs {
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||
}
|
||||
out[i] = addr
|
||||
lhMap[addr] = struct{}{}
|
||||
}
|
||||
|
||||
if !lh.amLighthouse && len(out) == 0 {
|
||||
if !lh.amLighthouse && len(lhMap) == 0 {
|
||||
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
|
||||
}
|
||||
|
||||
staticList := lh.GetStaticHostList()
|
||||
for i := range out {
|
||||
if _, ok := staticList[out[i]]; !ok {
|
||||
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
|
||||
for lhAddr, _ := range lhMap {
|
||||
if _, ok := staticList[lhAddr]; !ok {
|
||||
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
||||
@@ -487,7 +486,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
|
||||
lh.Lock()
|
||||
defer lh.Unlock()
|
||||
// Add an entry if we don't already have one
|
||||
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
|
||||
return lh.unlockedGetRemoteList(vpnAddrs)
|
||||
}
|
||||
|
||||
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
||||
@@ -520,15 +519,11 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
|
||||
}
|
||||
|
||||
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
||||
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
|
||||
staticList := lh.GetStaticHostList()
|
||||
for _, addr := range allVpnAddrs {
|
||||
if _, ok := staticList[addr]; ok {
|
||||
return
|
||||
}
|
||||
// First we check the static mapping
|
||||
// and do nothing if it is there
|
||||
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
// None of the VpnAddrs were present. Now we can do the deletes.
|
||||
lh.Lock()
|
||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||
if ok {
|
||||
@@ -570,7 +565,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
|
||||
am.unlockedSetHostnamesResults(hr)
|
||||
|
||||
for _, addrPort := range hr.GetAddrs() {
|
||||
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
|
||||
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
@@ -632,30 +627,23 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
|
||||
return len(calculatedV4) > 0 || len(calculatedV6) > 0
|
||||
}
|
||||
|
||||
// unlockedGetRemoteList assumes you have the lh lock
|
||||
// unlockedGetRemoteList
|
||||
// assumes you have the lh lock
|
||||
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
||||
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
|
||||
for i, addr := range allAddrs {
|
||||
am, ok := lh.addrMap[addr]
|
||||
if ok {
|
||||
if i != 0 {
|
||||
lh.addrMap[allAddrs[0]] = am
|
||||
}
|
||||
return am
|
||||
am, ok := lh.addrMap[allAddrs[0]]
|
||||
if !ok {
|
||||
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
|
||||
for _, addr := range allAddrs {
|
||||
lh.addrMap[addr] = am
|
||||
}
|
||||
}
|
||||
|
||||
am := NewRemoteList(allAddrs, lh.shouldAdd)
|
||||
for _, addr := range allAddrs {
|
||||
lh.addrMap[addr] = am
|
||||
}
|
||||
return am
|
||||
}
|
||||
|
||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
|
||||
Trace("remoteAllowList.Allow")
|
||||
}
|
||||
if !allow {
|
||||
@@ -710,22 +698,19 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
||||
}
|
||||
|
||||
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
||||
l := lh.GetLighthouses()
|
||||
for i := range l {
|
||||
if l[i] == vpnAddr {
|
||||
return true
|
||||
}
|
||||
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
||||
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
|
||||
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
|
||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
|
||||
l := lh.GetLighthouses()
|
||||
for i := range vpnAddrs {
|
||||
for j := range l {
|
||||
if l[j] == vpnAddrs[i] {
|
||||
return true
|
||||
}
|
||||
for _, a := range vpnAddr {
|
||||
if _, ok := l[a]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
@@ -767,7 +752,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
queried := 0
|
||||
lighthouses := lh.GetLighthouses()
|
||||
|
||||
for _, lhVpnAddr := range lighthouses {
|
||||
for lhVpnAddr := range lighthouses {
|
||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||
if hi != nil {
|
||||
v = hi.ConnectionState.myCert.Version()
|
||||
@@ -885,7 +870,7 @@ func (lh *LightHouse) SendUpdate() {
|
||||
updated := 0
|
||||
lighthouses := lh.GetLighthouses()
|
||||
|
||||
for _, lhVpnAddr := range lighthouses {
|
||||
for lhVpnAddr := range lighthouses {
|
||||
var v cert.Version
|
||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||
if hi != nil {
|
||||
@@ -943,6 +928,7 @@ func (lh *LightHouse) SendUpdate() {
|
||||
V4AddrPorts: v4,
|
||||
V6AddrPorts: v6,
|
||||
RelayVpnAddrs: relays,
|
||||
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1062,19 +1048,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
return
|
||||
}
|
||||
|
||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
useVersion := cert.Version1
|
||||
var queryVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
queryVpnAddr = netip.AddrFrom4(b)
|
||||
useVersion = 1
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
useVersion = 2
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
||||
Debugln("Dropping malformed HostQuery")
|
||||
}
|
||||
return
|
||||
}
|
||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1083,6 +1069,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostQueryReply
|
||||
if useVersion == cert.Version1 {
|
||||
if !queryVpnAddr.Is4() {
|
||||
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
|
||||
}
|
||||
b := queryVpnAddr.As4()
|
||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
||||
} else {
|
||||
@@ -1127,9 +1116,8 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
||||
if ok {
|
||||
whereToPunch = newDest
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
||||
}
|
||||
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
|
||||
//choosing to do nothing for now, but maybe we return an error?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1188,17 +1176,19 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
||||
if !r.Is4() {
|
||||
continue
|
||||
}
|
||||
|
||||
b = r.As4()
|
||||
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
|
||||
}
|
||||
|
||||
} else if v == cert.Version2 {
|
||||
for _, r := range c.relay.relay {
|
||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||
}
|
||||
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
||||
}
|
||||
//TODO: CERT-V2 don't panic
|
||||
panic("unsupported version")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1208,16 +1198,18 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
||||
return
|
||||
}
|
||||
|
||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
||||
}
|
||||
return
|
||||
lhh.lh.Lock()
|
||||
|
||||
var certVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
certVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
relays := n.Details.GetRelays()
|
||||
|
||||
lhh.lh.Lock()
|
||||
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
|
||||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
@@ -1242,24 +1234,27 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
return
|
||||
}
|
||||
|
||||
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
|
||||
var detailsVpnAddr netip.Addr
|
||||
var useVersion cert.Version
|
||||
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
|
||||
useVersion := cert.Version1
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
detailsVpnAddr = netip.AddrFrom4(b)
|
||||
useVersion = cert.Version1
|
||||
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
useVersion = cert.Version2
|
||||
} else {
|
||||
detailsVpnAddr = netip.Addr{}
|
||||
useVersion = cert.Version2
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
|
||||
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
|
||||
//Simple check that the host sent this not someone else
|
||||
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
||||
}
|
||||
@@ -1273,24 +1268,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
|
||||
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetRelay(fromVpnAddrs[0], relays)
|
||||
am.Unlock()
|
||||
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostUpdateNotificationAck
|
||||
switch useVersion {
|
||||
case cert.Version1:
|
||||
|
||||
if useVersion == cert.Version1 {
|
||||
if !fromVpnAddrs[0].Is4() {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
||||
return
|
||||
}
|
||||
vpnAddrB := fromVpnAddrs[0].As4()
|
||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
|
||||
case cert.Version2:
|
||||
// do nothing, we want to send a blank message
|
||||
default:
|
||||
} else if useVersion == cert.Version2 {
|
||||
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
|
||||
} else {
|
||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
||||
return
|
||||
}
|
||||
@@ -1308,20 +1303,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
|
||||
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
|
||||
//maybe one day we'll have a better idea, if it matters.
|
||||
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
|
||||
return
|
||||
}
|
||||
|
||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
empty := []byte{0}
|
||||
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
|
||||
punch := func(vpnPeer netip.AddrPort) {
|
||||
if !vpnPeer.IsValid() {
|
||||
return
|
||||
}
|
||||
@@ -1333,31 +1321,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
}()
|
||||
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
var logVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
logVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V4AddrPorts {
|
||||
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
punch(protoV4AddrPortToNetAddrPort(a))
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V6AddrPorts {
|
||||
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
punch(protoV6AddrPortToNetAddrPort(a))
|
||||
}
|
||||
|
||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||
// of a double nat or other difficult scenario, this may help establish
|
||||
// a tunnel.
|
||||
if lhh.lh.punchy.GetRespond() {
|
||||
var queryVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
queryVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
|
||||
}
|
||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||
// managed by a channel.
|
||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -1436,17 +1441,3 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
|
||||
}
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
|
||||
if d.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
|
||||
detailsVpnAddr := netip.AddrFrom4(b)
|
||||
return detailsVpnAddr, cert.Version1, nil
|
||||
} else if d.VpnAddr != nil {
|
||||
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
|
||||
return detailsVpnAddr, cert.Version2, nil
|
||||
} else {
|
||||
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,123 +493,3 @@ func Test_findNetworkUnion(t *testing.T) {
|
||||
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||
|
||||
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
|
||||
testStaticHost := netip.MustParseAddr("10.128.0.42")
|
||||
//myVpnIp := netip.MustParseAddr("10.128.0.2")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh1 := "10.128.0.2"
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"hosts": []any{lh1},
|
||||
"interval": "1s",
|
||||
}
|
||||
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
c.Settings["static_host_map"] = map[string]any{
|
||||
lh1: []any{"1.1.1.1:4242"},
|
||||
"10.128.0.42": []any{"1.2.3.4:4242"},
|
||||
}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
//test that we actually have the static entry:
|
||||
out := lh.Query(testStaticHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
|
||||
out.Rebuild([]netip.Prefix{}) //why tho
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//bolt on a lower numbered primary IP
|
||||
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
|
||||
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
|
||||
lh.addrMap[testSameHostNotStatic] = am
|
||||
out.Rebuild([]netip.Prefix{}) //???
|
||||
|
||||
//test that we actually have the static entry:
|
||||
out = lh.Query(testStaticHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//test that we actually have the static entry for BOTH:
|
||||
out2 := lh.Query(testSameHostNotStatic)
|
||||
assert.Same(t, out2, out)
|
||||
|
||||
//now do the delete
|
||||
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
|
||||
//verify
|
||||
out = lh.Query(testSameHostNotStatic)
|
||||
assert.NotNil(t, out)
|
||||
if out == nil {
|
||||
t.Fatal("expected non-nil query for the static host")
|
||||
}
|
||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
}
|
||||
|
||||
func TestLighthouse_DeletesWork(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||
testHost := netip.MustParseAddr("10.128.0.42")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh1 := "10.128.0.2"
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"hosts": []any{lh1},
|
||||
"interval": "1s",
|
||||
}
|
||||
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
c.Settings["static_host_map"] = map[string]any{
|
||||
lh1: []any{"1.1.1.1:4242"},
|
||||
}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
//insert the host
|
||||
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
|
||||
am.vpnAddrs = []netip.Addr{testHost}
|
||||
am.addrs = []netip.AddrPort{myUdpAddr2}
|
||||
lh.addrMap[testHost] = am
|
||||
am.Rebuild([]netip.Prefix{}) //???
|
||||
|
||||
//test that we actually have the entry:
|
||||
out := lh.Query(testHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testHost)
|
||||
out.Rebuild([]netip.Prefix{}) //why tho
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//now do the delete
|
||||
lh.DeleteVpnAddrs([]netip.Addr{testHost})
|
||||
//verify
|
||||
out = lh.Query(testHost)
|
||||
assert.Nil(t, out)
|
||||
}
|
||||
|
||||
70
main.go
70
main.go
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -144,20 +143,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
// set up our UDP listener
|
||||
udpConns := make([]udp.Conn, routines)
|
||||
port := c.GetInt("listen.port", 0)
|
||||
enableGSO := c.GetBool("listen.enable_gso", true)
|
||||
enableGRO := c.GetBool("listen.enable_gro", true)
|
||||
gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
||||
if gsoMaxSegments <= 0 {
|
||||
gsoMaxSegments = defaultGSOMaxSegments
|
||||
}
|
||||
if gsoMaxSegments > maxKernelGSOSegments {
|
||||
gsoMaxSegments = maxKernelGSOSegments
|
||||
}
|
||||
gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
|
||||
if gsoFlushTimeout < 0 {
|
||||
gsoFlushTimeout = 0
|
||||
}
|
||||
batchQueueDepth := c.GetInt("batch.queue_depth", 0)
|
||||
|
||||
if !configTest {
|
||||
rawListenHost := c.GetString("listen.host", "0.0.0.0")
|
||||
@@ -177,27 +162,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
listenHost = ips[0].Unmap()
|
||||
}
|
||||
|
||||
useWGDefault := runtime.GOOS == "linux"
|
||||
useWG := c.GetBool("listen.use_wireguard_stack", useWGDefault)
|
||||
var mkListener func(*logrus.Logger, netip.Addr, int, bool, int) (udp.Conn, error)
|
||||
if useWG {
|
||||
mkListener = udp.NewWireguardListener
|
||||
} else {
|
||||
mkListener = udp.NewListener
|
||||
}
|
||||
|
||||
for i := 0; i < routines; i++ {
|
||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||
udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
udpServer.ReloadConfig(c)
|
||||
if cfg, ok := udpServer.(interface {
|
||||
ConfigureOffload(bool, bool, int)
|
||||
}); ok {
|
||||
cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
|
||||
}
|
||||
udpConns[i] = udpServer
|
||||
|
||||
// If port is dynamic, discover it before the next pass through the for loop
|
||||
@@ -265,17 +236,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
||||
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
||||
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
||||
EnableGSO: enableGSO,
|
||||
EnableGRO: enableGRO,
|
||||
GSOMaxSegments: gsoMaxSegments,
|
||||
routines: routines,
|
||||
MessageMetrics: messageMetrics,
|
||||
version: buildVersion,
|
||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||
punchy: punchy,
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
BatchFlushInterval: gsoFlushTimeout,
|
||||
BatchQueueDepth: batchQueueDepth,
|
||||
l: l,
|
||||
}
|
||||
|
||||
@@ -287,9 +253,41 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
|
||||
ifce.writers = udpConns
|
||||
ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
|
||||
lightHouse.ifce = ifce
|
||||
|
||||
loadMultiPortConfig := func(c *config.C) {
|
||||
ifce.multiPort.Rx = c.GetBool("tun.multiport.rx_enabled", false)
|
||||
|
||||
tx := c.GetBool("tun.multiport.tx_enabled", false)
|
||||
|
||||
if tx && ifce.udpRaw == nil {
|
||||
ifce.udpRaw, err = udp.NewRawConn(l, c.GetString("listen.host", "0.0.0.0"), port, uint16(port))
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to get raw socket for tun.multiport.tx_enabled")
|
||||
ifce.udpRaw = nil
|
||||
tx = false
|
||||
}
|
||||
}
|
||||
|
||||
if tx {
|
||||
ifce.multiPort.TxBasePort = uint16(port)
|
||||
ifce.multiPort.TxPorts = c.GetInt("tun.multiport.tx_ports", 100)
|
||||
ifce.multiPort.TxHandshake = c.GetBool("tun.multiport.tx_handshake", false)
|
||||
ifce.multiPort.TxHandshakeDelay = int64(c.GetInt("tun.multiport.tx_handshake_delay", 2))
|
||||
ifce.udpRaw.ReloadConfig(c)
|
||||
}
|
||||
ifce.multiPort.Tx = tx
|
||||
|
||||
// TODO: if we upstream this, make this cleaner
|
||||
handshakeManager.udpRaw = ifce.udpRaw
|
||||
handshakeManager.multiPort = ifce.multiPort
|
||||
|
||||
l.WithField("multiPort", ifce.multiPort).Info("Multiport configured")
|
||||
}
|
||||
|
||||
loadMultiPortConfig(c)
|
||||
c.RegisterReloadCallback(loadMultiPortConfig)
|
||||
|
||||
ifce.RegisterConfigChangeCallbacks(c)
|
||||
ifce.reloadDisconnectInvalid(c)
|
||||
ifce.reloadSendRecvError(c)
|
||||
|
||||
515
nebula.pb.go
515
nebula.pb.go
@@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string {
|
||||
}
|
||||
|
||||
func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{8, 0}
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{9, 0}
|
||||
}
|
||||
|
||||
type NebulaMeta struct {
|
||||
@@ -541,20 +541,90 @@ func (m *NebulaHandshake) GetHmac() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
type MultiPortDetails struct {
|
||||
RxSupported bool `protobuf:"varint,1,opt,name=RxSupported,proto3" json:"RxSupported,omitempty"`
|
||||
TxSupported bool `protobuf:"varint,2,opt,name=TxSupported,proto3" json:"TxSupported,omitempty"`
|
||||
BasePort uint32 `protobuf:"varint,3,opt,name=BasePort,proto3" json:"BasePort,omitempty"`
|
||||
TotalPorts uint32 `protobuf:"varint,4,opt,name=TotalPorts,proto3" json:"TotalPorts,omitempty"`
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) Reset() { *m = MultiPortDetails{} }
|
||||
func (m *MultiPortDetails) String() string { return proto.CompactTextString(m) }
|
||||
func (*MultiPortDetails) ProtoMessage() {}
|
||||
func (*MultiPortDetails) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{7}
|
||||
}
|
||||
func (m *MultiPortDetails) XXX_Unmarshal(b []byte) error {
|
||||
return m.Unmarshal(b)
|
||||
}
|
||||
func (m *MultiPortDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
if deterministic {
|
||||
return xxx_messageInfo_MultiPortDetails.Marshal(b, m, deterministic)
|
||||
} else {
|
||||
b = b[:cap(b)]
|
||||
n, err := m.MarshalToSizedBuffer(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b[:n], nil
|
||||
}
|
||||
}
|
||||
func (m *MultiPortDetails) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_MultiPortDetails.Merge(m, src)
|
||||
}
|
||||
func (m *MultiPortDetails) XXX_Size() int {
|
||||
return m.Size()
|
||||
}
|
||||
func (m *MultiPortDetails) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_MultiPortDetails.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_MultiPortDetails proto.InternalMessageInfo
|
||||
|
||||
func (m *MultiPortDetails) GetRxSupported() bool {
|
||||
if m != nil {
|
||||
return m.RxSupported
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) GetTxSupported() bool {
|
||||
if m != nil {
|
||||
return m.TxSupported
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) GetBasePort() uint32 {
|
||||
if m != nil {
|
||||
return m.BasePort
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) GetTotalPorts() uint32 {
|
||||
if m != nil {
|
||||
return m.TotalPorts
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type NebulaHandshakeDetails struct {
|
||||
Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"`
|
||||
InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"`
|
||||
ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"`
|
||||
Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"`
|
||||
Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"`
|
||||
CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"`
|
||||
Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"`
|
||||
InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"`
|
||||
ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"`
|
||||
Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"`
|
||||
Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"`
|
||||
CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"`
|
||||
InitiatorMultiPort *MultiPortDetails `protobuf:"bytes,6,opt,name=InitiatorMultiPort,proto3" json:"InitiatorMultiPort,omitempty"`
|
||||
ResponderMultiPort *MultiPortDetails `protobuf:"bytes,7,opt,name=ResponderMultiPort,proto3" json:"ResponderMultiPort,omitempty"`
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} }
|
||||
func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) }
|
||||
func (*NebulaHandshakeDetails) ProtoMessage() {}
|
||||
func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{7}
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{8}
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error {
|
||||
return m.Unmarshal(b)
|
||||
@@ -625,6 +695,20 @@ func (m *NebulaHandshakeDetails) GetCertVersion() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetInitiatorMultiPort() *MultiPortDetails {
|
||||
if m != nil {
|
||||
return m.InitiatorMultiPort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetResponderMultiPort() *MultiPortDetails {
|
||||
if m != nil {
|
||||
return m.ResponderMultiPort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NebulaControl struct {
|
||||
Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"`
|
||||
InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"`
|
||||
@@ -639,7 +723,7 @@ func (m *NebulaControl) Reset() { *m = NebulaControl{} }
|
||||
func (m *NebulaControl) String() string { return proto.CompactTextString(m) }
|
||||
func (*NebulaControl) ProtoMessage() {}
|
||||
func (*NebulaControl) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{8}
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{9}
|
||||
}
|
||||
func (m *NebulaControl) XXX_Unmarshal(b []byte) error {
|
||||
return m.Unmarshal(b)
|
||||
@@ -730,6 +814,7 @@ func init() {
|
||||
proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort")
|
||||
proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing")
|
||||
proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake")
|
||||
proto.RegisterType((*MultiPortDetails)(nil), "nebula.MultiPortDetails")
|
||||
proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails")
|
||||
proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl")
|
||||
}
|
||||
@@ -737,57 +822,61 @@ func init() {
|
||||
func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) }
|
||||
|
||||
var fileDescriptor_2d65afa7693df5ef = []byte{
|
||||
// 785 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44,
|
||||
0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54,
|
||||
0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a,
|
||||
0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c,
|
||||
0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb,
|
||||
0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed,
|
||||
0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc,
|
||||
0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d,
|
||||
0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6,
|
||||
0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65,
|
||||
0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70,
|
||||
0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb,
|
||||
0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f,
|
||||
0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b,
|
||||
0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b,
|
||||
0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6,
|
||||
0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0,
|
||||
0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b,
|
||||
0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c,
|
||||
0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18,
|
||||
0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86,
|
||||
0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88,
|
||||
0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7,
|
||||
0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a,
|
||||
0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba,
|
||||
0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00,
|
||||
0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8,
|
||||
0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1,
|
||||
0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68,
|
||||
0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50,
|
||||
0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f,
|
||||
0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91,
|
||||
0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee,
|
||||
0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b,
|
||||
0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7,
|
||||
0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab,
|
||||
0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60,
|
||||
0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e,
|
||||
0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90,
|
||||
0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64,
|
||||
0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9,
|
||||
0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef,
|
||||
0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf,
|
||||
0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f,
|
||||
0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8,
|
||||
0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65,
|
||||
0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9,
|
||||
0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94,
|
||||
0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00,
|
||||
0x00,
|
||||
// 864 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x56, 0x4f, 0x6f, 0xe3, 0x44,
|
||||
0x14, 0x8f, 0x1d, 0xe7, 0x4f, 0x5f, 0x9a, 0xac, 0x79, 0x15, 0x25, 0x5d, 0x89, 0x28, 0xf8, 0x50,
|
||||
0xad, 0x38, 0x64, 0x51, 0x5b, 0x56, 0x1c, 0xd9, 0x06, 0xa1, 0xac, 0xb4, 0xed, 0x96, 0x21, 0x14,
|
||||
0x89, 0x0b, 0x9a, 0xc6, 0x43, 0x63, 0xc5, 0xf1, 0x78, 0xed, 0x31, 0x6a, 0xbe, 0x05, 0xe2, 0xb3,
|
||||
0xf0, 0x21, 0xe0, 0xb6, 0x47, 0x4e, 0x08, 0xb5, 0x47, 0x8e, 0x7c, 0x01, 0x34, 0xe3, 0x7f, 0xe3,
|
||||
0xc4, 0x6c, 0x6f, 0xf3, 0xde, 0xef, 0xf7, 0x7b, 0xfe, 0xcd, 0x9b, 0x79, 0x93, 0xc0, 0x7e, 0xc0,
|
||||
0x6e, 0x12, 0x9f, 0x4e, 0xc2, 0x88, 0x0b, 0x8e, 0xed, 0x34, 0x72, 0xfe, 0x31, 0x01, 0x2e, 0xd5,
|
||||
0xf2, 0x82, 0x09, 0x8a, 0x27, 0x60, 0xcd, 0x37, 0x21, 0x1b, 0x1a, 0x63, 0xe3, 0xd9, 0xe0, 0x64,
|
||||
0x34, 0xc9, 0x34, 0x25, 0x63, 0x72, 0xc1, 0xe2, 0x98, 0xde, 0x32, 0xc9, 0x22, 0x8a, 0x8b, 0xa7,
|
||||
0xd0, 0xf9, 0x8a, 0x09, 0xea, 0xf9, 0xf1, 0xd0, 0x1c, 0x1b, 0xcf, 0x7a, 0x27, 0x47, 0xbb, 0xb2,
|
||||
0x8c, 0x40, 0x72, 0xa6, 0xf3, 0xaf, 0x01, 0x3d, 0xad, 0x14, 0x76, 0xc1, 0xba, 0xe4, 0x01, 0xb3,
|
||||
0x1b, 0xd8, 0x87, 0xbd, 0x19, 0x8f, 0xc5, 0x37, 0x09, 0x8b, 0x36, 0xb6, 0x81, 0x08, 0x83, 0x22,
|
||||
0x24, 0x2c, 0xf4, 0x37, 0xb6, 0x89, 0x4f, 0xe1, 0x50, 0xe6, 0xbe, 0x0b, 0x5d, 0x2a, 0xd8, 0x25,
|
||||
0x17, 0xde, 0x4f, 0xde, 0x82, 0x0a, 0x8f, 0x07, 0x76, 0x13, 0x8f, 0xe0, 0x43, 0x89, 0x5d, 0xf0,
|
||||
0x9f, 0x99, 0x5b, 0x81, 0xac, 0x1c, 0xba, 0x4a, 0x82, 0xc5, 0xb2, 0x02, 0xb5, 0x70, 0x00, 0x20,
|
||||
0xa1, 0xef, 0x97, 0x9c, 0xae, 0x3d, 0xbb, 0x8d, 0x07, 0xf0, 0xa4, 0x8c, 0xd3, 0xcf, 0x76, 0xa4,
|
||||
0xb3, 0x2b, 0x2a, 0x96, 0xd3, 0x25, 0x5b, 0xac, 0xec, 0xae, 0x74, 0x56, 0x84, 0x29, 0x65, 0x0f,
|
||||
0x3f, 0x86, 0xa3, 0x7a, 0x67, 0x2f, 0x17, 0x2b, 0x1b, 0x9c, 0x3f, 0x4c, 0xf8, 0x60, 0xa7, 0x29,
|
||||
0xe8, 0x00, 0xbc, 0xf1, 0xdd, 0xeb, 0x30, 0x78, 0xe9, 0xba, 0x91, 0x6a, 0x7d, 0xff, 0xdc, 0x1c,
|
||||
0x1a, 0x44, 0xcb, 0xe2, 0x31, 0x74, 0x72, 0x42, 0x5b, 0x35, 0x79, 0x3f, 0x6f, 0xb2, 0xcc, 0x91,
|
||||
0x1c, 0xc4, 0x09, 0xd8, 0x6f, 0x7c, 0x97, 0x30, 0x9f, 0x6e, 0xb2, 0x54, 0x3c, 0x6c, 0x8d, 0x9b,
|
||||
0x59, 0xc5, 0x1d, 0x0c, 0x4f, 0xa0, 0x5f, 0x25, 0x77, 0xc6, 0xcd, 0x9d, 0xea, 0x55, 0x0a, 0x9e,
|
||||
0x41, 0xef, 0xfa, 0x4c, 0x2e, 0xaf, 0x78, 0x24, 0xe4, 0xa1, 0x4b, 0x05, 0xe6, 0x8a, 0x12, 0x22,
|
||||
0x3a, 0x4d, 0xa9, 0x5e, 0x94, 0x2a, 0x6b, 0x4b, 0xf5, 0x42, 0x53, 0x95, 0x34, 0x1c, 0x42, 0x67,
|
||||
0xc1, 0x93, 0x40, 0xb0, 0x68, 0xd8, 0x94, 0x8d, 0x21, 0x79, 0xe8, 0x1c, 0x83, 0xa5, 0x76, 0x3c,
|
||||
0x00, 0x73, 0xe6, 0xa9, 0xae, 0x59, 0xc4, 0x9c, 0x79, 0x32, 0x7e, 0xcd, 0xd5, 0x4d, 0xb4, 0x88,
|
||||
0xf9, 0x9a, 0x3b, 0x67, 0x00, 0xa5, 0x0d, 0xc4, 0x54, 0x95, 0x76, 0x99, 0xa4, 0x15, 0x10, 0x2c,
|
||||
0x89, 0x29, 0x4d, 0x9f, 0xa8, 0xb5, 0xf3, 0x25, 0x40, 0x69, 0xe3, 0xb1, 0x6f, 0x14, 0x15, 0x9a,
|
||||
0x5a, 0x85, 0xbb, 0x7c, 0xb0, 0xae, 0xbc, 0xe0, 0xf6, 0xfd, 0x83, 0x25, 0x19, 0x35, 0x83, 0x85,
|
||||
0x60, 0xcd, 0xbd, 0x35, 0xcb, 0xbe, 0xa3, 0xd6, 0x8e, 0xb3, 0x33, 0x36, 0x52, 0x6c, 0x37, 0x70,
|
||||
0x0f, 0x5a, 0xe9, 0x25, 0x34, 0x9c, 0x1f, 0xe1, 0x49, 0x5a, 0x77, 0x46, 0x03, 0x37, 0x5e, 0xd2,
|
||||
0x15, 0xc3, 0x2f, 0xca, 0x19, 0x35, 0xd4, 0xf5, 0xd9, 0x72, 0x50, 0x30, 0xb7, 0x07, 0x55, 0x9a,
|
||||
0x98, 0xad, 0xe9, 0x42, 0x99, 0xd8, 0x27, 0x6a, 0xed, 0xfc, 0x6a, 0x80, 0x7d, 0x91, 0xf8, 0xc2,
|
||||
0x93, 0x1b, 0xcd, 0x89, 0x63, 0xe8, 0x91, 0xbb, 0x6f, 0x93, 0x30, 0xe4, 0x91, 0x60, 0xae, 0xfa,
|
||||
0x4c, 0x97, 0xe8, 0x29, 0xc9, 0x98, 0x6b, 0x0c, 0x33, 0x65, 0x68, 0x29, 0x7c, 0x0a, 0xdd, 0x73,
|
||||
0x1a, 0x33, 0xad, 0x97, 0x45, 0x8c, 0x23, 0x80, 0x39, 0x17, 0xd4, 0xcf, 0xaf, 0x8f, 0x44, 0xb5,
|
||||
0x8c, 0xf3, 0x97, 0x09, 0x87, 0xf5, 0x9b, 0x91, 0x7b, 0x98, 0xb2, 0x48, 0x28, 0x4f, 0xfb, 0x44,
|
||||
0xad, 0xf1, 0x18, 0x06, 0xaf, 0x02, 0x4f, 0x78, 0x54, 0xf0, 0xe8, 0x55, 0xe0, 0xb2, 0xbb, 0xec,
|
||||
0xf8, 0xb7, 0xb2, 0x92, 0x47, 0x58, 0x1c, 0xf2, 0xc0, 0x65, 0x19, 0x2f, 0x35, 0xb6, 0x95, 0xc5,
|
||||
0x43, 0x68, 0x4f, 0x39, 0x5f, 0x79, 0x4c, 0x59, 0xb3, 0x48, 0x16, 0x15, 0x87, 0xd8, 0x2a, 0x0f,
|
||||
0x51, 0x36, 0x42, 0x7a, 0xb8, 0x66, 0x51, 0xec, 0xf1, 0x60, 0xd8, 0x55, 0x05, 0xf5, 0x14, 0xce,
|
||||
0x00, 0x0b, 0x1f, 0x45, 0xa7, 0xb3, 0xc9, 0x1f, 0xe6, 0x47, 0xb7, 0x7d, 0x04, 0xa4, 0x46, 0x23,
|
||||
0x2b, 0x15, 0x4e, 0xcb, 0x4a, 0x9d, 0xc7, 0x2a, 0xed, 0x6a, 0x9c, 0xdf, 0x9a, 0xd0, 0x4f, 0x1b,
|
||||
0x3c, 0xe5, 0x81, 0x88, 0xb8, 0x8f, 0x9f, 0x57, 0x2e, 0xf5, 0x27, 0xd5, 0x2b, 0x95, 0x91, 0x6a,
|
||||
0xee, 0xf5, 0x67, 0x70, 0x50, 0x18, 0x55, 0x2f, 0x8b, 0xde, 0xff, 0x3a, 0x48, 0x2a, 0x0a, 0x43,
|
||||
0x9a, 0x22, 0x3d, 0x89, 0x3a, 0x08, 0x3f, 0x85, 0x41, 0xfe, 0xd6, 0xcd, 0xb9, 0x9a, 0x78, 0xab,
|
||||
0x78, 0x57, 0xb7, 0x10, 0xfd, 0xcd, 0xfc, 0x3a, 0xe2, 0x6b, 0xc5, 0x6e, 0x15, 0xec, 0x1d, 0x0c,
|
||||
0x27, 0xd0, 0xd3, 0x0b, 0xd7, 0xbd, 0xc7, 0x3a, 0xa1, 0x78, 0x63, 0x8b, 0xe2, 0x9d, 0x1a, 0x45,
|
||||
0x95, 0xe2, 0xcc, 0xfe, 0xef, 0xe7, 0xf1, 0x10, 0x70, 0x1a, 0x31, 0x2a, 0x98, 0xe2, 0x13, 0xf6,
|
||||
0x36, 0x61, 0xb1, 0xb0, 0x0d, 0xfc, 0x08, 0x0e, 0x2a, 0x79, 0xd9, 0x92, 0x98, 0xd9, 0xe6, 0xf9,
|
||||
0xe9, 0xef, 0xf7, 0x23, 0xe3, 0xdd, 0xfd, 0xc8, 0xf8, 0xfb, 0x7e, 0x64, 0xfc, 0xf2, 0x30, 0x6a,
|
||||
0xbc, 0x7b, 0x18, 0x35, 0xfe, 0x7c, 0x18, 0x35, 0x7e, 0x38, 0xba, 0xf5, 0xc4, 0x32, 0xb9, 0x99,
|
||||
0x2c, 0xf8, 0xfa, 0x79, 0xec, 0xd3, 0xc5, 0x6a, 0xf9, 0xf6, 0x79, 0x6a, 0xe9, 0xa6, 0xad, 0xfe,
|
||||
0x25, 0x9c, 0xfe, 0x17, 0x00, 0x00, 0xff, 0xff, 0xd6, 0x71, 0x5a, 0xf8, 0x35, 0x08, 0x00, 0x00,
|
||||
}
|
||||
|
||||
func (m *NebulaMeta) Marshal() (dAtA []byte, err error) {
|
||||
@@ -1114,6 +1203,59 @@ func (m *NebulaHandshake) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
n, err := m.MarshalToSizedBuffer(dAtA[:size])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dAtA[:n], nil
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) MarshalTo(dAtA []byte) (int, error) {
|
||||
size := m.Size()
|
||||
return m.MarshalToSizedBuffer(dAtA[:size])
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
i := len(dAtA)
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if m.TotalPorts != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.TotalPorts))
|
||||
i--
|
||||
dAtA[i] = 0x20
|
||||
}
|
||||
if m.BasePort != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.BasePort))
|
||||
i--
|
||||
dAtA[i] = 0x18
|
||||
}
|
||||
if m.TxSupported {
|
||||
i--
|
||||
if m.TxSupported {
|
||||
dAtA[i] = 1
|
||||
} else {
|
||||
dAtA[i] = 0
|
||||
}
|
||||
i--
|
||||
dAtA[i] = 0x10
|
||||
}
|
||||
if m.RxSupported {
|
||||
i--
|
||||
if m.RxSupported {
|
||||
dAtA[i] = 1
|
||||
} else {
|
||||
dAtA[i] = 0
|
||||
}
|
||||
i--
|
||||
dAtA[i] = 0x8
|
||||
}
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
@@ -1139,6 +1281,30 @@ func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error)
|
||||
i--
|
||||
dAtA[i] = 0x40
|
||||
}
|
||||
if m.ResponderMultiPort != nil {
|
||||
{
|
||||
size, err := m.ResponderMultiPort.MarshalToSizedBuffer(dAtA[:i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
i -= size
|
||||
i = encodeVarintNebula(dAtA, i, uint64(size))
|
||||
}
|
||||
i--
|
||||
dAtA[i] = 0x3a
|
||||
}
|
||||
if m.InitiatorMultiPort != nil {
|
||||
{
|
||||
size, err := m.InitiatorMultiPort.MarshalToSizedBuffer(dAtA[:i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
i -= size
|
||||
i = encodeVarintNebula(dAtA, i, uint64(size))
|
||||
}
|
||||
i--
|
||||
dAtA[i] = 0x32
|
||||
}
|
||||
if m.Time != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.Time))
|
||||
i--
|
||||
@@ -1392,6 +1558,27 @@ func (m *NebulaHandshake) Size() (n int) {
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *MultiPortDetails) Size() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
var l int
|
||||
_ = l
|
||||
if m.RxSupported {
|
||||
n += 2
|
||||
}
|
||||
if m.TxSupported {
|
||||
n += 2
|
||||
}
|
||||
if m.BasePort != 0 {
|
||||
n += 1 + sovNebula(uint64(m.BasePort))
|
||||
}
|
||||
if m.TotalPorts != 0 {
|
||||
n += 1 + sovNebula(uint64(m.TotalPorts))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) Size() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
@@ -1414,6 +1601,14 @@ func (m *NebulaHandshakeDetails) Size() (n int) {
|
||||
if m.Time != 0 {
|
||||
n += 1 + sovNebula(uint64(m.Time))
|
||||
}
|
||||
if m.InitiatorMultiPort != nil {
|
||||
l = m.InitiatorMultiPort.Size()
|
||||
n += 1 + l + sovNebula(uint64(l))
|
||||
}
|
||||
if m.ResponderMultiPort != nil {
|
||||
l = m.ResponderMultiPort.Size()
|
||||
n += 1 + l + sovNebula(uint64(l))
|
||||
}
|
||||
if m.CertVersion != 0 {
|
||||
n += 1 + sovNebula(uint64(m.CertVersion))
|
||||
}
|
||||
@@ -2356,6 +2551,134 @@ func (m *NebulaHandshake) Unmarshal(dAtA []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *MultiPortDetails) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
preIndex := iNdEx
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
fieldNum := int32(wire >> 3)
|
||||
wireType := int(wire & 0x7)
|
||||
if wireType == 4 {
|
||||
return fmt.Errorf("proto: MultiPortDetails: wiretype end group for non-group")
|
||||
}
|
||||
if fieldNum <= 0 {
|
||||
return fmt.Errorf("proto: MultiPortDetails: illegal tag %d (wire type %d)", fieldNum, wire)
|
||||
}
|
||||
switch fieldNum {
|
||||
case 1:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field RxSupported", wireType)
|
||||
}
|
||||
var v int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
v |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
m.RxSupported = bool(v != 0)
|
||||
case 2:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field TxSupported", wireType)
|
||||
}
|
||||
var v int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
v |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
m.TxSupported = bool(v != 0)
|
||||
case 3:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field BasePort", wireType)
|
||||
}
|
||||
m.BasePort = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.BasePort |= uint32(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
case 4:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field TotalPorts", wireType)
|
||||
}
|
||||
m.TotalPorts = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.TotalPorts |= uint32(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipNebula(dAtA[iNdEx:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if (skippy < 0) || (iNdEx+skippy) < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if (iNdEx + skippy) > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
iNdEx += skippy
|
||||
}
|
||||
}
|
||||
|
||||
if iNdEx > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
@@ -2495,6 +2818,78 @@ func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error {
|
||||
break
|
||||
}
|
||||
}
|
||||
case 6:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field InitiatorMultiPort", wireType)
|
||||
}
|
||||
var msglen int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
msglen |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if msglen < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
postIndex := iNdEx + msglen
|
||||
if postIndex < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
if m.InitiatorMultiPort == nil {
|
||||
m.InitiatorMultiPort = &MultiPortDetails{}
|
||||
}
|
||||
if err := m.InitiatorMultiPort.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
|
||||
return err
|
||||
}
|
||||
iNdEx = postIndex
|
||||
case 7:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field ResponderMultiPort", wireType)
|
||||
}
|
||||
var msglen int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
msglen |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if msglen < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
postIndex := iNdEx + msglen
|
||||
if postIndex < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
if m.ResponderMultiPort == nil {
|
||||
m.ResponderMultiPort = &MultiPortDetails{}
|
||||
}
|
||||
if err := m.ResponderMultiPort.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
|
||||
return err
|
||||
}
|
||||
iNdEx = postIndex
|
||||
case 8:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType)
|
||||
|
||||
12
nebula.proto
12
nebula.proto
@@ -65,6 +65,13 @@ message NebulaHandshake {
|
||||
bytes Hmac = 2;
|
||||
}
|
||||
|
||||
message MultiPortDetails {
|
||||
bool RxSupported = 1;
|
||||
bool TxSupported = 2;
|
||||
uint32 BasePort = 3;
|
||||
uint32 TotalPorts = 4;
|
||||
}
|
||||
|
||||
message NebulaHandshakeDetails {
|
||||
bytes Cert = 1;
|
||||
uint32 InitiatorIndex = 2;
|
||||
@@ -72,8 +79,9 @@ message NebulaHandshakeDetails {
|
||||
uint64 Cookie = 4;
|
||||
uint64 Time = 5;
|
||||
uint32 CertVersion = 8;
|
||||
// reserved for WIP multiport
|
||||
reserved 6, 7;
|
||||
|
||||
MultiPortDetails InitiatorMultiPort = 6;
|
||||
MultiPortDetails ResponderMultiPort = 7;
|
||||
}
|
||||
|
||||
message NebulaControl {
|
||||
|
||||
75
outside.go
75
outside.go
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
@@ -20,7 +19,7 @@ const (
|
||||
minFwPacketLen = 4
|
||||
)
|
||||
|
||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) {
|
||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := h.Parse(packet)
|
||||
if err != nil {
|
||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||
@@ -62,7 +61,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
|
||||
switch h.Subtype {
|
||||
case header.MessageNone:
|
||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, ip, h.RemoteIndex) {
|
||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
|
||||
return
|
||||
}
|
||||
case header.MessageRelay:
|
||||
@@ -233,6 +232,16 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
||||
|
||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
|
||||
if udpAddr.IsValid() && hostinfo.remote != udpAddr {
|
||||
if hostinfo.multiportRx {
|
||||
// If the remote is sending with multiport, we aren't roaming unless
|
||||
// the IP has changed
|
||||
if hostinfo.remote.Addr().Compare(udpAddr.Addr()) == 0 {
|
||||
return
|
||||
}
|
||||
// Keep the port from the original hostinfo, because the remote is transmitting from multiport ports
|
||||
udpAddr = netip.AddrPortFrom(udpAddr.Addr(), hostinfo.remote.Port())
|
||||
}
|
||||
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
|
||||
hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
return
|
||||
@@ -255,18 +264,16 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
||||
|
||||
}
|
||||
|
||||
// handleEncrypted returns true if a packet should be processed, false otherwise
|
||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
||||
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
||||
if ci == nil {
|
||||
// If connectionstate exists and the replay protector allows, process packet
|
||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
|
||||
if addr.IsValid() {
|
||||
f.maybeSendRecvError(addr, h.RemoteIndex)
|
||||
return false
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
// If the window check fails, refuse to process the packet, but don't send a recv error
|
||||
if !ci.window.Check(f.l, h.MessageCounter) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -466,45 +473,23 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
|
||||
var (
|
||||
err error
|
||||
pkt *overlay.Packet
|
||||
)
|
||||
|
||||
if f.batches.tunQueue(q) != nil {
|
||||
pkt = f.batches.newPacket()
|
||||
if pkt != nil {
|
||||
out = pkt.Payload()[:0]
|
||||
}
|
||||
}
|
||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||
var err error
|
||||
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||
if err != nil {
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||
if addr.IsValid() {
|
||||
f.maybeSendRecvError(addr, recvIndex)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
err = newPacket(out, true, fwPacket)
|
||||
if err != nil {
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||
Warnf("Error while validating inbound packet")
|
||||
return false
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping out of window packet")
|
||||
return false
|
||||
@@ -512,9 +497,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason != nil {
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||
@@ -527,17 +509,8 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
}
|
||||
|
||||
f.connectionManager.In(hostinfo)
|
||||
if pkt != nil {
|
||||
pkt.Len = len(out)
|
||||
if f.batches.enqueueTun(q, pkt) {
|
||||
f.observeTunQueueLen(q)
|
||||
return true
|
||||
}
|
||||
f.writePacketToTun(q, pkt)
|
||||
return true
|
||||
}
|
||||
|
||||
if _, err = f.readers[q].Write(out); err != nil {
|
||||
_, err = f.readers[q].Write(out)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to write to tun")
|
||||
}
|
||||
return true
|
||||
@@ -574,6 +547,10 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||
return
|
||||
}
|
||||
|
||||
if !hostinfo.RecvErrorExceeded() {
|
||||
return
|
||||
}
|
||||
|
||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||
return
|
||||
|
||||
@@ -3,7 +3,6 @@ package overlay
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
@@ -16,84 +15,3 @@ type Device interface {
|
||||
RoutesFor(netip.Addr) routing.Gateways
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
||||
// Packet represents a single packet buffer with optional headroom to carry
|
||||
// metadata (for example virtio-net headers).
|
||||
type Packet struct {
|
||||
Buf []byte
|
||||
Offset int
|
||||
Len int
|
||||
release func()
|
||||
}
|
||||
|
||||
func (p *Packet) Payload() []byte {
|
||||
return p.Buf[p.Offset : p.Offset+p.Len]
|
||||
}
|
||||
|
||||
func (p *Packet) Reset() {
|
||||
p.Len = 0
|
||||
p.Offset = 0
|
||||
p.release = nil
|
||||
}
|
||||
|
||||
func (p *Packet) Release() {
|
||||
if p.release != nil {
|
||||
p.release()
|
||||
p.release = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Packet) Capacity() int {
|
||||
return len(p.Buf) - p.Offset
|
||||
}
|
||||
|
||||
// PacketPool manages reusable buffers with headroom.
|
||||
type PacketPool struct {
|
||||
headroom int
|
||||
blksz int
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func NewPacketPool(headroom, payload int) *PacketPool {
|
||||
p := &PacketPool{headroom: headroom, blksz: headroom + payload}
|
||||
p.pool.New = func() any {
|
||||
buf := make([]byte, p.blksz)
|
||||
return &Packet{Buf: buf, Offset: headroom}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PacketPool) Get() *Packet {
|
||||
pkt := p.pool.Get().(*Packet)
|
||||
pkt.Offset = p.headroom
|
||||
pkt.Len = 0
|
||||
pkt.release = func() { p.put(pkt) }
|
||||
return pkt
|
||||
}
|
||||
|
||||
func (p *PacketPool) put(pkt *Packet) {
|
||||
pkt.Reset()
|
||||
p.pool.Put(pkt)
|
||||
}
|
||||
|
||||
// BatchReader allows reading multiple packets into a shared pool with
|
||||
// preallocated headroom (e.g. virtio-net headers).
|
||||
type BatchReader interface {
|
||||
ReadIntoBatch(pool *PacketPool) ([]*Packet, error)
|
||||
}
|
||||
|
||||
// BatchWriter writes a slice of packets that carry their own metadata.
|
||||
type BatchWriter interface {
|
||||
WriteBatch(packets []*Packet) (int, error)
|
||||
}
|
||||
|
||||
// BatchCapableDevice describes a device that can efficiently read and write
|
||||
// batches of packets with virtio headroom.
|
||||
type BatchCapableDevice interface {
|
||||
Device
|
||||
BatchReader
|
||||
BatchWriter
|
||||
BatchHeadroom() int
|
||||
BatchPayloadCap() int
|
||||
BatchSize() int
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -72,51 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
func flipBytes(b []byte) []byte {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] ^= 0xFF
|
||||
}
|
||||
return b
|
||||
}
|
||||
func orBytes(a []byte, b []byte) []byte {
|
||||
ret := make([]byte, len(a))
|
||||
for i := 0; i < len(a); i++ {
|
||||
ret[i] = a[i] | b[i]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
||||
broadcast, _ := netip.AddrFromSlice(
|
||||
orBytes(
|
||||
cidr.Addr().AsSlice(),
|
||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
||||
),
|
||||
)
|
||||
return broadcast
|
||||
}
|
||||
|
||||
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||
for _, gateway := range gateways {
|
||||
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
||||
return gateway, nil
|
||||
}
|
||||
|
||||
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
||||
return gateway, nil
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
@@ -294,6 +295,7 @@ func (t *tun) activate6(network netip.Prefix) error {
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
|
||||
Flags: _IN6_IFF_NODAD,
|
||||
}
|
||||
|
||||
@@ -552,3 +554,13 @@ func (t *tun) Name() string {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
@@ -20,18 +22,12 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||
FIODGNAME = 0x80106678
|
||||
TUNSIFMODE = 0x8004745e
|
||||
TUNSIFHEAD = 0x80047460
|
||||
OSIOCAIFADDR_IN6 = 0x8088691b
|
||||
IN6_IFF_NODAD = 0x0020
|
||||
FIODGNAME = 0x80106678
|
||||
)
|
||||
|
||||
type fiodgnameArg struct {
|
||||
@@ -41,159 +37,43 @@ type fiodgnameArg struct {
|
||||
}
|
||||
|
||||
type ifreqRename struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Name [16]byte
|
||||
Data uintptr
|
||||
}
|
||||
|
||||
type ifreqDestroy struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Name [16]byte
|
||||
pad [16]byte
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Flags uint16
|
||||
}
|
||||
|
||||
type ifreqMTU struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
MTU int32
|
||||
}
|
||||
|
||||
type addrLifetime struct {
|
||||
Expire uint64
|
||||
Preferred uint64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
}
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
devFd int
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
head := make([]byte, 4)
|
||||
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&to[0], uint64(len(to))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
// fix bytes read number to exclude header
|
||||
bytesRead := int(n)
|
||||
if bytesRead < 0 {
|
||||
return bytesRead, err
|
||||
} else if bytesRead < 4 {
|
||||
return 0, err
|
||||
} else {
|
||||
return bytesRead - 4, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
if len(from) <= 1 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
ipVer := from[0] >> 4
|
||||
var head []byte
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
if ipVer == 4 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
||||
} else if ipVer == 6 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&from[0], uint64(len(from))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return int(n) - 4, err
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.devFd >= 0 {
|
||||
err := syscall.Close(t.devFd)
|
||||
if t.ReadWriteCloser != nil {
|
||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error closing device")
|
||||
return err
|
||||
}
|
||||
t.devFd = -1
|
||||
defer syscall.Close(s)
|
||||
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
||||
defer close(c)
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err == nil {
|
||||
defer syscall.Close(s)
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
}
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error destroying tunnel")
|
||||
}
|
||||
}()
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
|
||||
// wait up to 1 second so we start blocking at the ioctl
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
// Destroy the interface
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -205,37 +85,32 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var fd int
|
||||
var file *os.File
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName != "" {
|
||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
}
|
||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||
// If the device doesn't already exist, request a new one and rename it
|
||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
||||
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the name of the interface
|
||||
rawConn, err := file.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SyscallConn: %v", err)
|
||||
}
|
||||
|
||||
var name [16]byte
|
||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||
|
||||
if ctrlErr == nil {
|
||||
// set broadcast mode and multicast
|
||||
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
||||
}
|
||||
|
||||
if ctrlErr == nil {
|
||||
// turn on link-layer mode, to support ipv6
|
||||
ifhead := uint32(1)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
||||
}
|
||||
|
||||
var ctrlErr error
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
// Read the name of the interface
|
||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||
})
|
||||
if ctrlErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -247,7 +122,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
|
||||
// If the name doesn't match the desired interface name, rename it now
|
||||
if ifName != deviceName {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
s, err := syscall.Socket(
|
||||
syscall.AF_INET,
|
||||
syscall.SOCK_DGRAM,
|
||||
syscall.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -270,11 +149,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
devFd: fd,
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -293,111 +172,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
ifr := ifreqAlias4{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
},
|
||||
DstAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: getBroadcast(cidr).As4(),
|
||||
},
|
||||
MaskAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
},
|
||||
VHid: 0,
|
||||
}
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
ifr := ifreqAlias6{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
},
|
||||
PrefixMask: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
},
|
||||
Lifetime: addrLifetime{
|
||||
Expire: 0,
|
||||
Preferred: 0,
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
Flags: IN6_IFF_NODAD,
|
||||
}
|
||||
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
// Setup our default MTU
|
||||
err := t.setMTU()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
linkAddr, err := getLinkAddr(t.Device)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if linkAddr == nil {
|
||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||
}
|
||||
t.linkAddr = linkAddr
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) setMTU() error {
|
||||
// Set the MTU on the device
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
||||
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
@@ -462,16 +268,15 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := addRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -484,8 +289,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -500,120 +306,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
fmt.Println("DOING CHANGE")
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLinkAddr Gets the link address for the interface of the given name
|
||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
switch m := m.(type) {
|
||||
case *netroute.InterfaceMessage:
|
||||
if m.Name == name {
|
||||
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
||||
if ok {
|
||||
return sa, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -20,7 +19,6 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@@ -35,7 +33,6 @@ type tun struct {
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
wgDevice wgtun.Device
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
@@ -71,9 +68,7 @@ type ifreqQLEN struct {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
useWGDefault := runtime.GOOS == "linux"
|
||||
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -118,9 +113,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
useWGDefault := runtime.GOOS == "linux"
|
||||
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -130,45 +123,16 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) {
|
||||
var (
|
||||
rw io.ReadWriteCloser = file
|
||||
fd = int(file.Fd())
|
||||
wgDev wgtun.Device
|
||||
)
|
||||
|
||||
if useWireguard {
|
||||
dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err)
|
||||
}
|
||||
wgDev = dev
|
||||
rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU))
|
||||
fd = int(dev.File().Fd())
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: rw,
|
||||
fd: fd,
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
l: l,
|
||||
}
|
||||
if wgDev != nil {
|
||||
t.wgDevice = wgDev
|
||||
}
|
||||
if wgDev != nil {
|
||||
// replace ioctl fd with device file descriptor to keep route management working
|
||||
file = wgDev.File()
|
||||
t.fd = int(file.Fd())
|
||||
t.ioctlFd = file.Fd()
|
||||
}
|
||||
|
||||
if t.ioctlFd == 0 {
|
||||
t.ioctlFd = file.Fd()
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
@@ -329,6 +293,7 @@ func (t *tun) addIPs(link netlink.Link) error {
|
||||
|
||||
//add all new addresses
|
||||
for i := range newAddrs {
|
||||
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
|
||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||
return err
|
||||
@@ -396,11 +361,6 @@ func (t *tun) Activate() error {
|
||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
|
||||
const modeNone = 1
|
||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||
}
|
||||
|
||||
if err = t.addIPs(link); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -678,11 +638,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Dst == nil {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
||||
@@ -714,14 +669,6 @@ func (t *tun) Close() error {
|
||||
_ = t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
if t.wgDevice != nil {
|
||||
_ = t.wgDevice.Close()
|
||||
if t.ioctlFd > 0 {
|
||||
// underlying fd already closed by the device
|
||||
t.ioctlFd = 0
|
||||
}
|
||||
}
|
||||
|
||||
if t.ioctlFd > 0 {
|
||||
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
||||
}
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import "fmt"
|
||||
|
||||
func (t *tun) batchIO() (*wireguardTunIO, bool) {
|
||||
io, ok := t.ReadWriteCloser.(*wireguardTunIO)
|
||||
return io, ok
|
||||
}
|
||||
|
||||
func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
||||
io, ok := t.batchIO()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("wireguard batch I/O not enabled")
|
||||
}
|
||||
return io.ReadIntoBatch(pool)
|
||||
}
|
||||
|
||||
func (t *tun) WriteBatch(packets []*Packet) (int, error) {
|
||||
io, ok := t.batchIO()
|
||||
if ok {
|
||||
return io.WriteBatch(packets)
|
||||
}
|
||||
for _, pkt := range packets {
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
pkt.Release()
|
||||
}
|
||||
return len(packets), nil
|
||||
}
|
||||
|
||||
func (t *tun) BatchHeadroom() int {
|
||||
if io, ok := t.batchIO(); ok {
|
||||
return io.BatchHeadroom()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (t *tun) BatchPayloadCap() int {
|
||||
if io, ok := t.batchIO(); ok {
|
||||
return io.BatchPayloadCap()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (t *tun) BatchSize() int {
|
||||
if io, ok := t.batchIO(); ok {
|
||||
return io.BatchSize()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -4,12 +4,13 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@@ -19,42 +20,11 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
SIOCAIFADDR_IN6 = 0x8080696b
|
||||
TUNSIFHEAD = 0x80047442
|
||||
TUNSIFMODE = 0x80047458
|
||||
)
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
type ifreq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
data int
|
||||
}
|
||||
|
||||
type addrLifetime struct {
|
||||
Expire uint64
|
||||
Preferred uint64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
type ifreqDestroy struct {
|
||||
Name [16]byte
|
||||
pad [16]byte
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
@@ -64,18 +34,40 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
@@ -85,23 +77,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
|
||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -119,225 +105,40 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.f != nil {
|
||||
if err := t.f.Close(); err != nil {
|
||||
return fmt.Errorf("error closing tun file: %w", err)
|
||||
}
|
||||
|
||||
// t.f.Close should have handled it for us but let's be extra sure
|
||||
_ = unix.Close(t.fd)
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifr := ifreq{Name: t.deviceBytes()}
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
||||
}
|
||||
|
||||
var errno syscall.Errno
|
||||
var n uintptr
|
||||
err = rc.Read(func(fd uintptr) bool {
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
head := [4]byte{}
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&to[0], uint64(len(to))},
|
||||
}
|
||||
|
||||
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
if errno.Temporary() {
|
||||
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
if err == syscall.EBADF || err.Error() == "use of closed file" {
|
||||
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
|
||||
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
|
||||
}
|
||||
|
||||
if errno != 0 {
|
||||
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
|
||||
}
|
||||
|
||||
// fix bytes read number to exclude header
|
||||
bytesRead := int(n)
|
||||
if bytesRead < 0 {
|
||||
return bytesRead, nil
|
||||
} else if bytesRead < 4 {
|
||||
return 0, nil
|
||||
} else {
|
||||
return bytesRead - 4, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
if len(from) <= 1 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
ipVer := from[0] >> 4
|
||||
var head [4]byte
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
if ipVer == 4 {
|
||||
head[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
head[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var errno syscall.Errno
|
||||
var n uintptr
|
||||
err = rc.Write(func(fd uintptr) bool {
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&from[0], uint64(len(from))},
|
||||
}
|
||||
|
||||
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
// According to NetBSD documentation for TUN, writes will only return errors in which
|
||||
// this packet will never be delivered so just go on living life.
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
return int(n) - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
var req ifreqAlias4
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
}
|
||||
var err error
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
var req ifreqAlias6
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
}
|
||||
req.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
}
|
||||
req.Lifetime = addrLifetime{
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
mode := int32(unix.IFF_BROADCAST)
|
||||
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun device mode: %w", err)
|
||||
}
|
||||
|
||||
v := 1
|
||||
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun device head: %w", err)
|
||||
}
|
||||
|
||||
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||
}
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err = t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
func (t *tun) Activate() error {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
@@ -396,23 +197,21 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,8 +224,10 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -441,109 +242,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,50 +4,23 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
SIOCAIFADDR_IN6 = 0x8080691a
|
||||
)
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime [2]uint32
|
||||
}
|
||||
|
||||
type ifreq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
data int
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
@@ -55,46 +28,48 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
|
||||
io.ReadWriteCloser
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
return t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
}
|
||||
|
||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
}
|
||||
|
||||
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -112,154 +87,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.f != nil {
|
||||
if err := t.f.Close(); err != nil {
|
||||
return fmt.Errorf("error closing tun file: %w", err)
|
||||
}
|
||||
|
||||
// t.f.Close should have handled it for us but let's be extra sure
|
||||
_ = unix.Close(t.fd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.f.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
copy(buf[4:], from)
|
||||
|
||||
n, err := t.f.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
var req ifreqAlias4
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||
}
|
||||
|
||||
err = addRoute(cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
var req ifreqAlias6
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
}
|
||||
req.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
}
|
||||
req.Lifetime[0] = 0xffffffff
|
||||
req.Lifetime[1] = 0xffffffff
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||
}
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err = t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
@@ -297,42 +124,63 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,9 +192,10 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -355,115 +204,52 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||
Seq: 1,
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.ReadWriteCloser.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
copy(buf[4:], from)
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
n, err := t.ReadWriteCloser.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
||||
)
|
||||
|
||||
type wireguardTunIO struct {
|
||||
dev wgtun.Device
|
||||
mtu int
|
||||
batchSize int
|
||||
|
||||
readMu sync.Mutex
|
||||
readBuffers [][]byte
|
||||
readLens []int
|
||||
legacyBuf []byte
|
||||
|
||||
writeMu sync.Mutex
|
||||
writeBuf []byte
|
||||
writeWrap [][]byte
|
||||
writeBuffers [][]byte
|
||||
}
|
||||
|
||||
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
||||
batch := dev.BatchSize()
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
if mtu <= 0 {
|
||||
mtu = DefaultMTU
|
||||
}
|
||||
return &wireguardTunIO{
|
||||
dev: dev,
|
||||
mtu: mtu,
|
||||
batchSize: batch,
|
||||
readLens: make([]int, batch),
|
||||
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||
writeWrap: make([][]byte, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Read(p []byte) (int, error) {
|
||||
w.readMu.Lock()
|
||||
defer w.readMu.Unlock()
|
||||
|
||||
bufs := w.readBuffers
|
||||
if len(bufs) == 0 {
|
||||
bufs = [][]byte{w.legacyBuf}
|
||||
w.readBuffers = bufs
|
||||
}
|
||||
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
length := w.readLens[0]
|
||||
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
|
||||
return length, nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
||||
if len(p) > w.mtu {
|
||||
return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
|
||||
}
|
||||
w.writeMu.Lock()
|
||||
defer w.writeMu.Unlock()
|
||||
buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
|
||||
for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
|
||||
buf[i] = 0
|
||||
}
|
||||
copy(buf[wgtun.VirtioNetHdrLen:], p)
|
||||
w.writeWrap[0] = buf
|
||||
n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
||||
if pool == nil {
|
||||
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
|
||||
}
|
||||
|
||||
w.readMu.Lock()
|
||||
defer w.readMu.Unlock()
|
||||
|
||||
if len(w.readBuffers) < w.batchSize {
|
||||
w.readBuffers = make([][]byte, w.batchSize)
|
||||
}
|
||||
if len(w.readLens) < w.batchSize {
|
||||
w.readLens = make([]int, w.batchSize)
|
||||
}
|
||||
|
||||
packets := make([]*Packet, w.batchSize)
|
||||
requiredHeadroom := w.BatchHeadroom()
|
||||
requiredPayload := w.BatchPayloadCap()
|
||||
headroom := 0
|
||||
for i := 0; i < w.batchSize; i++ {
|
||||
pkt := pool.Get()
|
||||
if pkt == nil {
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
|
||||
}
|
||||
if pkt.Capacity() < requiredPayload {
|
||||
pkt.Release()
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
|
||||
}
|
||||
if i == 0 {
|
||||
headroom = pkt.Offset
|
||||
if headroom < requiredHeadroom {
|
||||
pkt.Release()
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
|
||||
}
|
||||
} else if pkt.Offset != headroom {
|
||||
pkt.Release()
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
|
||||
}
|
||||
packets[i] = pkt
|
||||
w.readBuffers[i] = pkt.Buf
|
||||
}
|
||||
|
||||
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
|
||||
if err != nil {
|
||||
releasePackets(packets)
|
||||
return nil, err
|
||||
}
|
||||
if n == 0 {
|
||||
releasePackets(packets)
|
||||
return nil, nil
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
packets[i].Len = w.readLens[i]
|
||||
}
|
||||
for i := n; i < w.batchSize; i++ {
|
||||
packets[i].Release()
|
||||
packets[i] = nil
|
||||
}
|
||||
return packets[:n], nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
|
||||
if len(packets) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
requiredHeadroom := w.BatchHeadroom()
|
||||
offset := packets[0].Offset
|
||||
if offset < requiredHeadroom {
|
||||
releasePackets(packets)
|
||||
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
|
||||
}
|
||||
for _, pkt := range packets {
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
if pkt.Offset != offset {
|
||||
releasePackets(packets)
|
||||
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
|
||||
}
|
||||
limit := pkt.Offset + pkt.Len
|
||||
if limit > len(pkt.Buf) {
|
||||
releasePackets(packets)
|
||||
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
|
||||
}
|
||||
}
|
||||
w.writeMu.Lock()
|
||||
defer w.writeMu.Unlock()
|
||||
|
||||
if len(w.writeBuffers) < len(packets) {
|
||||
w.writeBuffers = make([][]byte, len(packets))
|
||||
}
|
||||
for i, pkt := range packets {
|
||||
if pkt == nil {
|
||||
w.writeBuffers[i] = nil
|
||||
continue
|
||||
}
|
||||
limit := pkt.Offset + pkt.Len
|
||||
w.writeBuffers[i] = pkt.Buf[:limit]
|
||||
}
|
||||
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
releasePackets(packets)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) BatchHeadroom() int {
|
||||
return wgtun.VirtioNetHdrLen
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) BatchPayloadCap() int {
|
||||
return w.mtu
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) BatchSize() int {
|
||||
return w.batchSize
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func releasePackets(pkts []*Packet) {
|
||||
for _, pkt := range pkts {
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -180,7 +180,6 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
|
||||
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
|
||||
}
|
||||
|
||||
// Set up the parameters which include the peer's public key
|
||||
|
||||
88
pki.go
88
pki.go
@@ -100,62 +100,55 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
currentState := p.cs.Load()
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new v1 cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v1Cert != nil {
|
||||
//TODO: CERT-V2 we should be able to tear this down
|
||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
if newState.v2Cert != nil {
|
||||
if currentState.v2Cert == nil {
|
||||
//adding certs is fine, actually
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||
if newState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||
}
|
||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
@@ -180,6 +173,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
|
||||
p.cs.Store(newState)
|
||||
|
||||
//TODO: CERT-V2 newState needs a stringer that does json
|
||||
if initial {
|
||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||
} else {
|
||||
@@ -365,9 +359,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
||||
}
|
||||
|
||||
if v1.Networks()[0] != v2.Networks()[0] {
|
||||
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
|
||||
}
|
||||
//TODO: CERT-V2 make sure v2 has v1s address
|
||||
|
||||
cs.initiatingVersion = dv
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ type RemoteList struct {
|
||||
// The full list of vpn addresses assigned to this host
|
||||
vpnAddrs []netip.Addr
|
||||
|
||||
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
|
||||
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
||||
addrs []netip.AddrPort
|
||||
|
||||
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
||||
@@ -201,10 +201,8 @@ type RemoteList struct {
|
||||
// For learned addresses, this is the vpnIp that sent the packet
|
||||
cache map[netip.Addr]*cache
|
||||
|
||||
hr *hostnamesResults
|
||||
|
||||
// shouldAdd is a nillable function that decides if x should be added to addrs.
|
||||
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
|
||||
hr *hostnamesResults
|
||||
shouldAdd func(netip.Addr) bool
|
||||
|
||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||
// They should not be tried again during a handshake
|
||||
@@ -215,7 +213,7 @@ type RemoteList struct {
|
||||
}
|
||||
|
||||
// NewRemoteList creates a new empty RemoteList
|
||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
|
||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
|
||||
r := &RemoteList{
|
||||
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
||||
addrs: make([]netip.AddrPort, 0),
|
||||
@@ -370,15 +368,6 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
|
||||
return c
|
||||
}
|
||||
|
||||
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
|
||||
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
|
||||
r.Lock()
|
||||
r.badRemotes = nil
|
||||
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
|
||||
copy(r.vpnAddrs, vpnAddrs)
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
// ResetBlockedRemotes locks and clears the blocked remotes list
|
||||
func (r *RemoteList) ResetBlockedRemotes() {
|
||||
r.Lock()
|
||||
@@ -588,7 +577,7 @@ func (r *RemoteList) unlockedCollect() {
|
||||
|
||||
dnsAddrs := r.hr.GetAddrs()
|
||||
for _, addr := range dnsAddrs {
|
||||
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
|
||||
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
||||
if !r.unlockedIsBad(addr) {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
12
udp/conn.go
12
udp/conn.go
@@ -22,18 +22,6 @@ type Conn interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Datagram represents a UDP payload destined to a specific address.
|
||||
type Datagram struct {
|
||||
Payload []byte
|
||||
Addr netip.AddrPort
|
||||
}
|
||||
|
||||
// BatchConn can send multiple datagrams in one syscall.
|
||||
type BatchConn interface {
|
||||
Conn
|
||||
WriteBatch(pkts []Datagram) error
|
||||
}
|
||||
|
||||
type NoopConn struct{}
|
||||
|
||||
func (NoopConn) Rebind() error {
|
||||
|
||||
@@ -310,51 +310,31 @@ func (u *StdConn) Close() error {
|
||||
}
|
||||
|
||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||
if len(udpConns) == 0 {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
type statsProvider struct {
|
||||
index int
|
||||
conn *StdConn
|
||||
}
|
||||
|
||||
providers := make([]statsProvider, 0, len(udpConns))
|
||||
for i, c := range udpConns {
|
||||
if sc, ok := c.(*StdConn); ok {
|
||||
providers = append(providers, statsProvider{index: i, conn: sc})
|
||||
}
|
||||
}
|
||||
|
||||
if len(providers) == 0 {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
||||
var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
|
||||
var meminfo [unix.SK_MEMINFO_VARS]uint32
|
||||
if err := providers[0].conn.getMemInfo(&meminfo); err != nil {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
udpGauges := make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(providers))
|
||||
for i, provider := range providers {
|
||||
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", provider.index), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", provider.index), nil),
|
||||
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
|
||||
udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
|
||||
for i := range udpConns {
|
||||
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return func() {
|
||||
for i, provider := range providers {
|
||||
if err := provider.conn.getMemInfo(&meminfo); err == nil {
|
||||
for i, gauges := range udpGauges {
|
||||
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
|
||||
for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
|
||||
udpGauges[i][j].Update(int64(meminfo[j]))
|
||||
gauges[j].Update(int64(meminfo[j]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
16
udp/udp_raw.go
Normal file
16
udp/udp_raw.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package udp
|
||||
|
||||
import mathrand "math/rand"
|
||||
|
||||
type SendPortGetter interface {
|
||||
// UDPSendPort returns the port to use
|
||||
UDPSendPort(maxPort int) uint16
|
||||
}
|
||||
|
||||
type randomSendPort struct{}
|
||||
|
||||
func (randomSendPort) UDPSendPort(maxPort int) uint16 {
|
||||
return uint16(mathrand.Intn(maxPort))
|
||||
}
|
||||
|
||||
var RandomSendPort = randomSendPort{}
|
||||
191
udp/udp_raw_linux.go
Normal file
191
udp/udp_raw_linux.go
Normal file
@@ -0,0 +1,191 @@
|
||||
//go:build !android && !e2e_testing
|
||||
// +build !android,!e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// RawOverhead is the number of bytes that need to be reserved at the start of
|
||||
// the raw bytes passed to (*RawConn).WriteTo. This is used by WriteTo to prefix
|
||||
// the IP and UDP headers.
|
||||
const RawOverhead = 28
|
||||
|
||||
type RawConn struct {
|
||||
sysFd int
|
||||
basePort uint16
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewRawConn(l *logrus.Logger, ip string, port int, basePort uint16) (*RawConn, error) {
|
||||
syscall.ForkLock.RLock()
|
||||
// With IPPROTO_UDP, the linux kernel tries to deliver every UDP packet
|
||||
// received in the system to our socket. This constantly overflows our
|
||||
// buffer and marks our socket as having dropped packets. This makes the
|
||||
// stats on the socket useless.
|
||||
//
|
||||
// In contrast, IPPROTO_RAW is not delivered any packets and thus our read
|
||||
// buffer will not fill up and mark as having dropped packets. The only
|
||||
// difference is that we have to assemble the IP header as well, but this
|
||||
// is fairly easy since Linux does the checksum for us.
|
||||
//
|
||||
// TODO: How to get this working with Inet6 correctly? I was having issues
|
||||
// with the source address when testing before, probably need to `bind(2)`?
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW)
|
||||
if err == nil {
|
||||
unix.CloseOnExec(fd)
|
||||
}
|
||||
syscall.ForkLock.RUnlock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We only want to send, not recv. This will hopefully help the kernel avoid
|
||||
// wasting time on us
|
||||
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 0); err != nil {
|
||||
return nil, fmt.Errorf("unable to set SO_RCVBUF: %s", err)
|
||||
}
|
||||
|
||||
var lip [16]byte
|
||||
copy(lip[:], net.ParseIP(ip))
|
||||
|
||||
// TODO do we need to `bind(2)` so that we send from the correct address/interface?
|
||||
if err = unix.Bind(fd, &unix.SockaddrInet6{Addr: lip, Port: port}); err != nil {
|
||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||
}
|
||||
|
||||
return &RawConn{
|
||||
sysFd: fd,
|
||||
basePort: basePort,
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WriteTo must be called with raw leaving the first `udp.RawOverhead` bytes empty,
|
||||
// for the IP/UDP headers.
|
||||
func (u *RawConn) WriteTo(raw []byte, fromPort uint16, ip netip.AddrPort) error {
|
||||
var rsa unix.RawSockaddrInet4
|
||||
rsa.Family = unix.AF_INET
|
||||
rsa.Addr = ip.Addr().As4()
|
||||
|
||||
totalLen := len(raw)
|
||||
udpLen := totalLen - ipv4.HeaderLen
|
||||
|
||||
// IP header
|
||||
raw[0] = byte(ipv4.Version<<4 | (ipv4.HeaderLen >> 2 & 0x0f))
|
||||
raw[1] = 0 // tos
|
||||
binary.BigEndian.PutUint16(raw[2:4], uint16(totalLen))
|
||||
binary.BigEndian.PutUint16(raw[4:6], 0) // id (linux does it for us)
|
||||
binary.BigEndian.PutUint16(raw[6:8], 0) // frag options
|
||||
raw[8] = byte(64) // ttl
|
||||
raw[9] = byte(17) // protocol
|
||||
binary.BigEndian.PutUint16(raw[10:12], 0) // checksum (linux does it for us)
|
||||
binary.BigEndian.PutUint32(raw[12:16], 0) // src (linux does it for us)
|
||||
copy(raw[16:20], rsa.Addr[:]) // dst
|
||||
|
||||
// UDP header
|
||||
fromPort = u.basePort + fromPort
|
||||
binary.BigEndian.PutUint16(raw[20:22], uint16(fromPort)) // src port
|
||||
binary.BigEndian.PutUint16(raw[22:24], uint16(ip.Port())) // dst port
|
||||
binary.BigEndian.PutUint16(raw[24:26], uint16(udpLen)) // UDP length
|
||||
binary.BigEndian.PutUint16(raw[26:28], 0) // checksum (optional)
|
||||
|
||||
for {
|
||||
_, _, err := unix.Syscall6(
|
||||
unix.SYS_SENDTO,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&raw[0])),
|
||||
uintptr(len(raw)),
|
||||
uintptr(0),
|
||||
uintptr(unsafe.Pointer(&rsa)),
|
||||
uintptr(unix.SizeofSockaddrInet4),
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
return &net.OpError{Op: "sendto", Err: err}
|
||||
}
|
||||
|
||||
//TODO: handle incomplete writes
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *RawConn) ReloadConfig(c *config.C) {
|
||||
b := c.GetInt("listen.write_buffer", 0)
|
||||
if b <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := u.SetSendBuffer(b); err != nil {
|
||||
u.l.WithError(err).Error("Failed to set listen.write_buffer")
|
||||
return
|
||||
}
|
||||
|
||||
s, err := u.GetSendBuffer()
|
||||
if err != nil {
|
||||
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
|
||||
return
|
||||
}
|
||||
|
||||
u.l.WithField("size", s).Info("listen.write_buffer was set")
|
||||
}
|
||||
|
||||
func (u *RawConn) SetSendBuffer(n int) error {
|
||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
|
||||
}
|
||||
|
||||
func (u *RawConn) GetSendBuffer() (int, error) {
|
||||
return unix.GetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUF)
|
||||
}
|
||||
|
||||
func (u *RawConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||
var vallen uint32 = 4 * unix.SK_MEMINFO_VARS
|
||||
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
|
||||
if err != 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRawStatsEmitter(rawConn *RawConn) func() {
|
||||
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
||||
var gauges [unix.SK_MEMINFO_VARS]metrics.Gauge
|
||||
var meminfo [unix.SK_MEMINFO_VARS]uint32
|
||||
if err := rawConn.getMemInfo(&meminfo); err == nil {
|
||||
gauges = [unix.SK_MEMINFO_VARS]metrics.Gauge{
|
||||
metrics.GetOrRegisterGauge("raw.rmem_alloc", nil),
|
||||
metrics.GetOrRegisterGauge("raw.rcvbuf", nil),
|
||||
metrics.GetOrRegisterGauge("raw.wmem_alloc", nil),
|
||||
metrics.GetOrRegisterGauge("raw.sndbuf", nil),
|
||||
metrics.GetOrRegisterGauge("raw.fwd_alloc", nil),
|
||||
metrics.GetOrRegisterGauge("raw.wmem_queued", nil),
|
||||
metrics.GetOrRegisterGauge("raw.optmem", nil),
|
||||
metrics.GetOrRegisterGauge("raw.backlog", nil),
|
||||
metrics.GetOrRegisterGauge("raw.drops", nil),
|
||||
}
|
||||
} else {
|
||||
// return no-op because we don't support SO_MEMINFO
|
||||
return func() {}
|
||||
}
|
||||
|
||||
return func() {
|
||||
if err := rawConn.getMemInfo(&meminfo); err == nil {
|
||||
for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
|
||||
gauges[j].Update(int64(meminfo[j]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
29
udp/udp_raw_unsupported.go
Normal file
29
udp/udp_raw_unsupported.go
Normal file
@@ -0,0 +1,29 @@
|
||||
//go:build !linux || android || e2e_testing
|
||||
// +build !linux android e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
const RawOverhead = 0
|
||||
|
||||
type RawConn struct{}
|
||||
|
||||
func NewRawConn(l *logrus.Logger, ip string, port int, basePort uint16) (*RawConn, error) {
|
||||
return nil, fmt.Errorf("multiport tx is not supported on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
func (u *RawConn) WriteTo(raw []byte, fromPort uint16, addr netip.AddrPort) error {
|
||||
return fmt.Errorf("multiport tx is not supported on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
func (u *RawConn) ReloadConfig(c *config.C) {}
|
||||
|
||||
func NewRawStatsEmitter(rawConn *RawConn) func() { return func() {} }
|
||||
@@ -1,225 +0,0 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||
)
|
||||
|
||||
// WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
|
||||
type WGConn struct {
|
||||
l *logrus.Logger
|
||||
bind *wgconn.StdNetBind
|
||||
recvers []wgconn.ReceiveFunc
|
||||
batch int
|
||||
reqBatch int
|
||||
localIP netip.Addr
|
||||
localPort uint16
|
||||
enableGSO bool
|
||||
enableGRO bool
|
||||
gsoMaxSeg int
|
||||
closed atomic.Bool
|
||||
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
|
||||
func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
bind := wgconn.NewStdNetBindForAddr(ip, multi)
|
||||
recvers, actualPort, err := bind.Open(uint16(port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if batch <= 0 {
|
||||
batch = bind.BatchSize()
|
||||
} else if batch > bind.BatchSize() {
|
||||
batch = bind.BatchSize()
|
||||
}
|
||||
return &WGConn{
|
||||
l: l,
|
||||
bind: bind,
|
||||
recvers: recvers,
|
||||
batch: batch,
|
||||
reqBatch: batch,
|
||||
localIP: ip,
|
||||
localPort: actualPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *WGConn) Rebind() error {
|
||||
// WireGuard's bind does not support rebinding in place.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
|
||||
if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
|
||||
// Fallback to wildcard IPv4 for display purposes.
|
||||
return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
|
||||
}
|
||||
return netip.AddrPortFrom(c.localIP, c.localPort), nil
|
||||
}
|
||||
|
||||
func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
|
||||
batchSize := c.batch
|
||||
packets := make([][]byte, batchSize)
|
||||
for i := range packets {
|
||||
packets[i] = make([]byte, MTU)
|
||||
}
|
||||
sizes := make([]int, batchSize)
|
||||
endpoints := make([]wgconn.Endpoint, batchSize)
|
||||
|
||||
for {
|
||||
if c.closed.Load() {
|
||||
return
|
||||
}
|
||||
n, err := fn(packets, sizes, endpoints)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
if c.l != nil {
|
||||
c.l.WithError(err).Debug("wireguard UDP listener receive error")
|
||||
}
|
||||
continue
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
|
||||
if !ok {
|
||||
if c.l != nil {
|
||||
c.l.Warn("wireguard UDP listener received unexpected endpoint type")
|
||||
}
|
||||
continue
|
||||
}
|
||||
addr := stdEp.AddrPort
|
||||
r(addr, packets[i][:sizes[i]])
|
||||
endpoints[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGConn) ListenOut(r EncReader) {
|
||||
for _, fn := range c.recvers {
|
||||
go c.listen(fn, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
if c.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
ep := &wgconn.StdNetEndpoint{AddrPort: addr}
|
||||
return c.bind.Send([][]byte{b}, ep)
|
||||
}
|
||||
|
||||
func (c *WGConn) WriteBatch(datagrams []Datagram) error {
|
||||
if len(datagrams) == 0 {
|
||||
return nil
|
||||
}
|
||||
if c.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
max := c.batch
|
||||
if max <= 0 {
|
||||
max = len(datagrams)
|
||||
if max == 0 {
|
||||
max = 1
|
||||
}
|
||||
}
|
||||
bufs := make([][]byte, 0, max)
|
||||
var (
|
||||
current netip.AddrPort
|
||||
endpoint *wgconn.StdNetEndpoint
|
||||
haveAddr bool
|
||||
)
|
||||
flush := func() error {
|
||||
if len(bufs) == 0 || endpoint == nil {
|
||||
bufs = bufs[:0]
|
||||
return nil
|
||||
}
|
||||
err := c.bind.Send(bufs, endpoint)
|
||||
bufs = bufs[:0]
|
||||
return err
|
||||
}
|
||||
|
||||
for _, d := range datagrams {
|
||||
if len(d.Payload) == 0 || !d.Addr.IsValid() {
|
||||
continue
|
||||
}
|
||||
if !haveAddr || d.Addr != current {
|
||||
if err := flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
current = d.Addr
|
||||
endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
|
||||
haveAddr = true
|
||||
}
|
||||
bufs = append(bufs, d.Payload)
|
||||
if len(bufs) >= max {
|
||||
if err := flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return flush()
|
||||
}
|
||||
|
||||
func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
|
||||
c.enableGSO = enableGSO
|
||||
c.enableGRO = enableGRO
|
||||
if maxSegments <= 0 {
|
||||
maxSegments = 1
|
||||
} else if maxSegments > wgconn.IdealBatchSize {
|
||||
maxSegments = wgconn.IdealBatchSize
|
||||
}
|
||||
c.gsoMaxSeg = maxSegments
|
||||
|
||||
effectiveBatch := c.reqBatch
|
||||
if enableGSO && c.bind != nil {
|
||||
bindBatch := c.bind.BatchSize()
|
||||
if effectiveBatch < bindBatch {
|
||||
if c.l != nil {
|
||||
c.l.WithFields(logrus.Fields{
|
||||
"requested": c.reqBatch,
|
||||
"effective": bindBatch,
|
||||
}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
|
||||
}
|
||||
effectiveBatch = bindBatch
|
||||
}
|
||||
}
|
||||
c.batch = effectiveBatch
|
||||
|
||||
if c.l != nil {
|
||||
c.l.WithFields(logrus.Fields{
|
||||
"enableGSO": enableGSO,
|
||||
"enableGRO": enableGRO,
|
||||
"gsoMaxSegments": maxSegments,
|
||||
}).Debug("configured wireguard UDP offload")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGConn) ReloadConfig(*config.C) {
|
||||
// WireGuard bind currently does not expose runtime configuration knobs.
|
||||
}
|
||||
|
||||
func (c *WGConn) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
c.closed.Store(true)
|
||||
err = c.bind.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build !linux || android || e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewWireguardListener is only available on Linux builds.
|
||||
func NewWireguardListener(*logrus.Logger, netip.Addr, int, bool, int) (Conn, error) {
|
||||
return nil, fmt.Errorf("wireguard experimental UDP listener is only supported on Linux")
|
||||
}
|
||||
@@ -1,539 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
)
|
||||
|
||||
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||
type StdNetBind struct {
|
||||
mu sync.Mutex // protects all fields except as specified
|
||||
ipv4 *net.UDPConn
|
||||
ipv6 *net.UDPConn
|
||||
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||
|
||||
// these three fields are not guarded by mu
|
||||
udpAddrPool sync.Pool
|
||||
ipv4MsgsPool sync.Pool
|
||||
ipv6MsgsPool sync.Pool
|
||||
|
||||
blackhole4 bool
|
||||
blackhole6 bool
|
||||
|
||||
listenAddr4 string
|
||||
listenAddr6 string
|
||||
bindV4 bool
|
||||
bindV6 bool
|
||||
reusePort bool
|
||||
}
|
||||
|
||||
func newStdNetBind() *StdNetBind {
|
||||
return &StdNetBind{
|
||||
udpAddrPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &net.UDPAddr{
|
||||
IP: make([]byte, 16),
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
ipv4MsgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv4.Message, IdealBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, srcControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
|
||||
ipv6MsgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, srcControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
bindV4: true,
|
||||
bindV6: true,
|
||||
reusePort: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewStdNetBind creates a bind that listens on all interfaces.
|
||||
func NewStdNetBind() *StdNetBind {
|
||||
return newStdNetBind()
|
||||
}
|
||||
|
||||
// NewStdNetBindForAddr creates a bind that listens on a specific address.
|
||||
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
|
||||
// IPv6 socket will be created.
|
||||
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
|
||||
b := newStdNetBind()
|
||||
if addr.IsValid() {
|
||||
if addr.IsUnspecified() {
|
||||
// keep dual-stack defaults with empty listen addresses
|
||||
} else if addr.Is4() {
|
||||
b.listenAddr4 = addr.Unmap().String()
|
||||
b.bindV4 = true
|
||||
b.bindV6 = false
|
||||
} else {
|
||||
b.listenAddr6 = addr.Unmap().String()
|
||||
b.bindV6 = true
|
||||
b.bindV4 = false
|
||||
}
|
||||
}
|
||||
b.reusePort = reusePort
|
||||
return b
|
||||
}
|
||||
|
||||
type StdNetEndpoint struct {
|
||||
// AddrPort is the endpoint destination.
|
||||
netip.AddrPort
|
||||
// src is the current sticky source address and interface index, if supported.
|
||||
src struct {
|
||||
netip.Addr
|
||||
ifidx int32
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
_ Endpoint = &StdNetEndpoint{}
|
||||
)
|
||||
|
||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StdNetEndpoint{
|
||||
AddrPort: e,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) ClearSrc() {
|
||||
e.src.ifidx = 0
|
||||
e.src.Addr = netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return e.src.Addr
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
return e.src.ifidx
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return e.src.Addr.String()
|
||||
}
|
||||
|
||||
func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
|
||||
lc := listenConfig()
|
||||
if s.reusePort {
|
||||
base := lc.Control
|
||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
if base != nil {
|
||||
if err := base(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.Control(func(fd uintptr) {
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
addr := ":" + strconv.Itoa(port)
|
||||
if host != "" {
|
||||
addr = net.JoinHostPort(host, strconv.Itoa(port))
|
||||
}
|
||||
|
||||
conn, err := lc.ListenPacket(context.Background(), network, addr)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Retrieve port.
|
||||
laddr := conn.LocalAddr()
|
||||
uaddr, err := net.ResolveUDPAddr(
|
||||
laddr.Network(),
|
||||
laddr.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
|
||||
if !s.bindV4 {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
host := s.listenAddr4
|
||||
conn, actualPort, err := s.listenNet("udp4", host, port)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
return nil, nil, port, err
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
return conn, nil, actualPort, nil
|
||||
}
|
||||
pc := ipv4.NewPacketConn(conn)
|
||||
return conn, pc, actualPort, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
|
||||
if !s.bindV6 {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
host := s.listenAddr6
|
||||
conn, actualPort, err := s.listenNet("udp6", host, port)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
return nil, nil, port, err
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
return conn, nil, actualPort, nil
|
||||
}
|
||||
pc := ipv6.NewPacketConn(conn)
|
||||
return conn, pc, actualPort, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var tries int
|
||||
|
||||
if s.ipv4 != nil || s.ipv6 != nil {
|
||||
return nil, 0, ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
||||
// If uport is 0, we can retry on failure.
|
||||
again:
|
||||
port := int(uport)
|
||||
var v4conn *net.UDPConn
|
||||
var v6conn *net.UDPConn
|
||||
var v4pc *ipv4.PacketConn
|
||||
var v6pc *ipv6.PacketConn
|
||||
|
||||
v4conn, v4pc, port, err = s.openIPv4(port)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Listen on the same port as we're using for ipv4.
|
||||
v6conn, v6pc, port, err = s.openIPv6(port)
|
||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||
if v4conn != nil {
|
||||
v4conn.Close()
|
||||
}
|
||||
tries++
|
||||
goto again
|
||||
}
|
||||
if err != nil {
|
||||
if v4conn != nil {
|
||||
v4conn.Close()
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var fns []ReceiveFunc
|
||||
if v4conn != nil {
|
||||
s.ipv4 = v4conn
|
||||
if v4pc != nil {
|
||||
s.ipv4PC = v4pc
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
|
||||
}
|
||||
if v6conn != nil {
|
||||
s.ipv6 = v6conn
|
||||
if v6pc != nil {
|
||||
s.ipv6PC = v6pc
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
|
||||
}
|
||||
if len(fns) == 0 {
|
||||
return nil, 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
|
||||
return fns, uint16(port), nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer s.ipv4MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
sizes[i] = msg.N
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||
defer s.ipv6MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
sizes[i] = msg.N
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||
// rename the IdealBatchSize constant to BatchSize.
|
||||
func (s *StdNetBind) BatchSize() int {
|
||||
if runtime.GOOS == "linux" {
|
||||
return IdealBatchSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err1, err2 error
|
||||
if s.ipv4 != nil {
|
||||
err1 = s.ipv4.Close()
|
||||
s.ipv4 = nil
|
||||
s.ipv4PC = nil
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
err2 = s.ipv6.Close()
|
||||
s.ipv6 = nil
|
||||
s.ipv6PC = nil
|
||||
}
|
||||
s.blackhole4 = false
|
||||
s.blackhole6 = false
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||
s.mu.Lock()
|
||||
blackhole := s.blackhole4
|
||||
conn := s.ipv4
|
||||
var (
|
||||
pc4 *ipv4.PacketConn
|
||||
pc6 *ipv6.PacketConn
|
||||
)
|
||||
is6 := false
|
||||
if endpoint.DstIP().Is6() {
|
||||
blackhole = s.blackhole6
|
||||
conn = s.ipv6
|
||||
pc6 = s.ipv6PC
|
||||
is6 = true
|
||||
} else {
|
||||
pc4 = s.ipv4PC
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if blackhole {
|
||||
return nil
|
||||
}
|
||||
if conn == nil {
|
||||
return syscall.EAFNOSUPPORT
|
||||
}
|
||||
if is6 {
|
||||
return s.send6(conn, pc6, endpoint, bufs)
|
||||
} else {
|
||||
return s.send4(conn, pc4, endpoint, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
as4 := ep.DstIP().As4()
|
||||
copy(ua.IP, as4[:])
|
||||
ua.IP = ua.IP[:4]
|
||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
for i, buf := range bufs {
|
||||
(*msgs)[i].Buffers[0] = buf
|
||||
(*msgs)[i].Addr = ua
|
||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||
}
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
)
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
for {
|
||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
for j := start; j < len(bufs); j++ {
|
||||
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
|
||||
if werr != nil {
|
||||
err = werr
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if n == len((*msgs)[start:len(bufs)]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
}
|
||||
} else {
|
||||
for i, buf := range bufs {
|
||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.udpAddrPool.Put(ua)
|
||||
s.ipv4MsgsPool.Put(msgs)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
as16 := ep.DstIP().As16()
|
||||
copy(ua.IP, as16[:])
|
||||
ua.IP = ua.IP[:16]
|
||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||
for i, buf := range bufs {
|
||||
(*msgs)[i].Buffers[0] = buf
|
||||
(*msgs)[i].Addr = ua
|
||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||
}
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
)
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
for {
|
||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
for j := start; j < len(bufs); j++ {
|
||||
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
|
||||
if werr != nil {
|
||||
err = werr
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if n == len((*msgs)[start:len(bufs)]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
}
|
||||
} else {
|
||||
for i, buf := range bufs {
|
||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.udpAddrPool.Put(ua)
|
||||
s.ipv6MsgsPool.Put(msgs)
|
||||
return err
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
||||
)
|
||||
|
||||
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||
// into packets. On a successful read it returns the number of elements of
|
||||
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||
// and eps slice with a length greater than or equal to the length of packets.
|
||||
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||
|
||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||
//
|
||||
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
||||
// depending on the platform-specific implementation.
|
||||
type Bind interface {
|
||||
// Open puts the Bind into a listening state on a given port and reports the actual
|
||||
// port that it bound to. Passing zero results in a random selection.
|
||||
// fns is the set of functions that will be called to receive packets.
|
||||
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
|
||||
|
||||
// Close closes the Bind listener.
|
||||
// All fns returned by Open must return net.ErrClosed after a call to Close.
|
||||
Close() error
|
||||
|
||||
// SetMark sets the mark for each packet sent through this Bind.
|
||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
SetMark(mark uint32) error
|
||||
|
||||
// Send writes one or more packets in bufs to address ep. The length of
|
||||
// bufs must not exceed BatchSize().
|
||||
Send(bufs [][]byte, ep Endpoint) error
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string.
|
||||
ParseEndpoint(s string) (Endpoint, error)
|
||||
|
||||
// BatchSize is the number of buffers expected to be passed to
|
||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||
BatchSize() int
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
// tied to a single network interface. Used by wireguard-windows.
|
||||
type BindSocketToInterface interface {
|
||||
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
||||
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
||||
}
|
||||
|
||||
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
||||
// file descriptor peeked at. Used by wireguard-android.
|
||||
type PeekLookAtSocketFd interface {
|
||||
PeekLookAtSocketFd4() (fd int, err error)
|
||||
PeekLookAtSocketFd6() (fd int, err error)
|
||||
}
|
||||
|
||||
// An Endpoint maintains the source/destination caching for a peer.
|
||||
//
|
||||
// dst: the remote address of a peer ("endpoint" in uapi terminology)
|
||||
// src: the local address from which datagrams originate going to the peer
|
||||
type Endpoint interface {
|
||||
ClearSrc() // clears the source address
|
||||
SrcToString() string // returns the local source address (ip:port)
|
||||
DstToString() string // returns the destination address (ip:port)
|
||||
DstToBytes() []byte // used for mac2 cookie calculations
|
||||
DstIP() netip.Addr
|
||||
SrcIP() netip.Addr
|
||||
}
|
||||
|
||||
var (
|
||||
ErrBindAlreadyOpen = errors.New("bind is already open")
|
||||
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
||||
)
|
||||
|
||||
func (fn ReceiveFunc) PrettyName() string {
|
||||
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
||||
// 0. cheese/taco.beansIPv6.func12.func21218-fm
|
||||
name = strings.TrimSuffix(name, "-fm")
|
||||
// 1. cheese/taco.beansIPv6.func12.func21218
|
||||
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
|
||||
name = name[idx+1:]
|
||||
// 2. taco.beansIPv6.func12.func21218
|
||||
}
|
||||
for {
|
||||
var idx int
|
||||
for idx = len(name) - 1; idx >= 0; idx-- {
|
||||
if name[idx] < '0' || name[idx] > '9' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx == len(name)-1 {
|
||||
break
|
||||
}
|
||||
const dotFunc = ".func"
|
||||
if !strings.HasSuffix(name[:idx+1], dotFunc) {
|
||||
break
|
||||
}
|
||||
name = name[:idx+1-len(dotFunc)]
|
||||
// 3. taco.beansIPv6.func12
|
||||
// 4. taco.beansIPv6
|
||||
}
|
||||
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
|
||||
name = name[idx+1:]
|
||||
// 5. beansIPv6
|
||||
}
|
||||
if name == "" {
|
||||
return fmt.Sprintf("%p", fn)
|
||||
}
|
||||
if strings.HasSuffix(name, "IPv4") {
|
||||
return "v4"
|
||||
}
|
||||
if strings.HasSuffix(name, "IPv6") {
|
||||
return "v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
||||
// the max supported by a default configuration of macOS. Some platforms will
|
||||
// silently clamp the value to other maximums, such as linux clamping to
|
||||
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
||||
// around this limitation)
|
||||
const socketBufferSize = 7 << 20
|
||||
|
||||
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||
// It is used to apply platform specific configuration to the socket prior to
|
||||
// bind.
|
||||
type controlFn func(network, address string, c syscall.RawConn) error
|
||||
|
||||
// controlFns is a list of functions that are called from the listen config
|
||||
// that can apply socket options.
|
||||
var controlFns = []controlFn{}
|
||||
|
||||
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||
// information OOB configuration for sticky sockets.
|
||||
func listenConfig() *net.ListenConfig {
|
||||
return &net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
for _, fn := range controlFns {
|
||||
if err := fn(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
|
||||
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
||||
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
||||
// fail silently - the result of failure is lower performance on very fast
|
||||
// links or high latency links.
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
// Set up to *mem_max
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||
// Set beyond *mem_max if CAP_NET_ADMIN
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
||||
})
|
||||
},
|
||||
|
||||
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
var err error
|
||||
switch network {
|
||||
case "udp4":
|
||||
if runtime.GOOS != "android" {
|
||||
c.Control(func(fd uintptr) {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||
})
|
||||
}
|
||||
case "udp6":
|
||||
c.Control(func(fd uintptr) {
|
||||
if runtime.GOOS != "android" {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||
})
|
||||
default:
|
||||
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||
}
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
func NewDefaultBind() Bind { return NewStdNetBind() }
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func errShouldDisableUDPGSO(err error) bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func errShouldDisableUDPGSO(err error) bool {
|
||||
var serr *os.SyscallError
|
||||
if errors.As(err, &serr) {
|
||||
// EIO is returned by udp_send_skb() if the device driver does not have
|
||||
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
||||
// See:
|
||||
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
||||
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
||||
return serr.Err == unix.EIO
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net"
|
||||
|
||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
return
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
rc, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = rc.Control(func(fd uintptr) {
|
||||
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||
txOffload = errSyscall == nil
|
||||
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
|
||||
rxOffload = errSyscall == nil && opt == 1
|
||||
})
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
return txOffload, rxOffload
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
}
|
||||
|
||||
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
||||
// offloading control data.
|
||||
const gsoControlSize = 0
|
||||
@@ -1,65 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
sizeOfGSOData = 2
|
||||
)
|
||||
|
||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||
}
|
||||
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
||||
var gso uint16
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
||||
return int(gso), nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
||||
// data in control untouched.
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
existingLen := len(*control)
|
||||
avail := cap(*control) - existingLen
|
||||
space := unix.CmsgSpace(sizeOfGSOData)
|
||||
if avail < space {
|
||||
return
|
||||
}
|
||||
*control = (*control)[:cap(*control)]
|
||||
gsoControl := (*control)[existingLen:]
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
||||
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
||||
*control = (*control)[:existingLen+space]
|
||||
}
|
||||
|
||||
// gsoControlSize returns the recommended buffer size for pooling UDP
|
||||
// offloading control data.
|
||||
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
||||
@@ -1,64 +0,0 @@
|
||||
//go:build linux || openbsd || freebsd
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var fwmarkIoctl int
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
fwmarkIoctl = 36 /* unix.SO_MARK */
|
||||
case "freebsd":
|
||||
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
|
||||
case "openbsd":
|
||||
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||
var operr error
|
||||
if fwmarkIoctl == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.ipv4 != nil {
|
||||
fd, err := s.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fd.Control(func(fd uintptr) {
|
||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||
})
|
||||
if err == nil {
|
||||
err = operr
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
fd, err := s.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fd.Control(func(fd uintptr) {
|
||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||
})
|
||||
if err == nil {
|
||||
err = operr
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net/netip"
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
||||
// {get,set}srcControl feature set, but use alternatively named flags and need
|
||||
// ports and require testing.
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||
// offloading control data.
|
||||
const stickyControlSize = 0
|
||||
|
||||
const StdNetSupportsStickySockets = false
|
||||
@@ -1,116 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
ep.ClearSrc()
|
||||
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem []byte = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IP &&
|
||||
hdr.Type == unix.IP_PKTINFO {
|
||||
|
||||
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
|
||||
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
|
||||
ep.src.ifidx = info.Ifindex
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||
hdr.Type == unix.IPV6_PKTINFO {
|
||||
|
||||
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
|
||||
ep.src.Addr = netip.AddrFrom16(info.Addr)
|
||||
ep.src.ifidx = int32(info.Ifindex)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
|
||||
// panics if buf is of insufficient size.
|
||||
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
|
||||
size := int(unsafe.Sizeof(t))
|
||||
if len(buf) < size {
|
||||
panic("pktInfoFromBuf: buffer too small")
|
||||
}
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
|
||||
return t
|
||||
}
|
||||
|
||||
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
||||
// and source ifindex found in ep. control's len will be set to 0 in the event
|
||||
// that ep is a default value.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
*control = (*control)[:cap(*control)]
|
||||
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
if len(*control) < srcControlSize {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
|
||||
if ep.SrcIP().Is4() {
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||
info.Ifindex = ep.src.ifidx
|
||||
if ep.SrcIP().IsValid() {
|
||||
info.Spec_dst = ep.SrcIP().As4()
|
||||
}
|
||||
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
||||
} else {
|
||||
hdr.Level = unix.IPPROTO_IPV6
|
||||
hdr.Type = unix.IPV6_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
||||
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||
info.Ifindex = uint32(ep.src.ifidx)
|
||||
if ep.SrcIP().IsValid() {
|
||||
info.Addr = ep.SrcIP().As16()
|
||||
}
|
||||
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||
|
||||
const StdNetSupportsStickySockets = true
|
||||
@@ -1,42 +0,0 @@
|
||||
package tun
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||
ac := initial
|
||||
i := 0
|
||||
n := len(b)
|
||||
for n >= 4 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
|
||||
n -= 4
|
||||
i += 4
|
||||
}
|
||||
for n >= 2 {
|
||||
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
|
||||
n -= 2
|
||||
i += 2
|
||||
}
|
||||
if n == 1 {
|
||||
ac += uint64(b[i]) << 8
|
||||
}
|
||||
return ac
|
||||
}
|
||||
|
||||
func checksum(b []byte, initial uint64) uint16 {
|
||||
ac := checksumNoFold(b, initial)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
return uint16(ac)
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
||||
sum := checksumNoFold(srcAddr, 0)
|
||||
sum = checksumNoFold(dstAddr, sum)
|
||||
sum = checksumNoFold([]byte{0, protocol}, sum)
|
||||
tmp := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||
return checksumNoFold(tmp, sum)
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
package tun
|
||||
|
||||
const VirtioNetHdrLen = virtioNetHdrLen
|
||||
@@ -1,630 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"unsafe"
|
||||
|
||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var ErrTooManySegments = errors.New("tun: too many segments for TSO")
|
||||
|
||||
const tcpFlagsOffset = 13
|
||||
|
||||
const (
|
||||
tcpFlagFIN uint8 = 0x01
|
||||
tcpFlagPSH uint8 = 0x08
|
||||
tcpFlagACK uint8 = 0x10
|
||||
)
|
||||
|
||||
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
|
||||
// kernel symbol is virtio_net_hdr.
|
||||
type virtioNetHdr struct {
|
||||
flags uint8
|
||||
gsoType uint8
|
||||
hdrLen uint16
|
||||
gsoSize uint16
|
||||
csumStart uint16
|
||||
csumOffset uint16
|
||||
}
|
||||
|
||||
func (v *virtioNetHdr) decode(b []byte) error {
|
||||
if len(b) < virtioNetHdrLen {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *virtioNetHdr) encode(b []byte) error {
|
||||
if len(b) < virtioNetHdrLen {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
||||
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
||||
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
||||
)
|
||||
|
||||
// flowKey represents the key for a flow.
|
||||
type flowKey struct {
|
||||
srcAddr, dstAddr [16]byte
|
||||
srcPort, dstPort uint16
|
||||
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
|
||||
}
|
||||
|
||||
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
|
||||
type tcpGROTable struct {
|
||||
itemsByFlow map[flowKey][]tcpGROItem
|
||||
itemsPool [][]tcpGROItem
|
||||
}
|
||||
|
||||
func newTCPGROTable() *tcpGROTable {
|
||||
t := &tcpGROTable{
|
||||
itemsByFlow: make(map[flowKey][]tcpGROItem, wgconn.IdealBatchSize),
|
||||
itemsPool: make([][]tcpGROItem, wgconn.IdealBatchSize),
|
||||
}
|
||||
for i := range t.itemsPool {
|
||||
t.itemsPool[i] = make([]tcpGROItem, 0, wgconn.IdealBatchSize)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
|
||||
key := flowKey{}
|
||||
addrSize := dstAddr - srcAddr
|
||||
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
|
||||
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
|
||||
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
||||
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
||||
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
||||
return key
|
||||
}
|
||||
|
||||
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||
// returning the packets found for the flow, or inserting a new one if none
|
||||
// is found.
|
||||
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||
items, ok := t.itemsByFlow[key]
|
||||
if ok {
|
||||
return items, ok
|
||||
}
|
||||
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// insert an item in the table for the provided packet and packet metadata.
|
||||
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||
item := tcpGROItem{
|
||||
key: key,
|
||||
bufsIndex: uint16(bufsIndex),
|
||||
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
||||
iphLen: uint8(tcphOffset),
|
||||
tcphLen: uint8(tcphLen),
|
||||
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
||||
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
||||
}
|
||||
items, ok := t.itemsByFlow[key]
|
||||
if !ok {
|
||||
items = t.newItems()
|
||||
}
|
||||
items = append(items, item)
|
||||
t.itemsByFlow[key] = items
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
||||
items, _ := t.itemsByFlow[item.key]
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
|
||||
items, _ := t.itemsByFlow[key]
|
||||
items = append(items[:i], items[i+1:]...)
|
||||
t.itemsByFlow[key] = items
|
||||
}
|
||||
|
||||
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
||||
// of a GRO evaluation across a vector of packets.
|
||||
type tcpGROItem struct {
|
||||
key flowKey
|
||||
sentSeq uint32 // the sequence number
|
||||
bufsIndex uint16 // the index into the original bufs slice
|
||||
numMerged uint16 // the number of packets merged into this item
|
||||
gsoSize uint16 // payload size
|
||||
iphLen uint8 // ip header len
|
||||
tcphLen uint8 // tcp header len
|
||||
pshSet bool // psh flag is set
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) newItems() []tcpGROItem {
|
||||
var items []tcpGROItem
|
||||
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
||||
return items
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) reset() {
|
||||
for k, items := range t.itemsByFlow {
|
||||
items = items[:0]
|
||||
t.itemsPool = append(t.itemsPool, items)
|
||||
delete(t.itemsByFlow, k)
|
||||
}
|
||||
}
|
||||
|
||||
// canCoalesce represents the outcome of checking if two TCP packets are
|
||||
// candidates for coalescing.
|
||||
type canCoalesce int
|
||||
|
||||
const (
|
||||
coalescePrepend canCoalesce = -1
|
||||
coalesceUnavailable canCoalesce = 0
|
||||
coalesceAppend canCoalesce = 1
|
||||
)
|
||||
|
||||
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||
// described by item. This function makes considerations that match the kernel's
|
||||
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
||||
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||
if tcphLen != item.tcphLen {
|
||||
// cannot coalesce with unequal tcp options len
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if tcphLen > 20 {
|
||||
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
||||
// cannot coalesce with unequal tcp options
|
||||
return coalesceUnavailable
|
||||
}
|
||||
}
|
||||
if pkt[0]>>4 == 6 {
|
||||
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
|
||||
// cannot coalesce with unequal Traffic class values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if pkt[7] != pktTarget[7] {
|
||||
// cannot coalesce with unequal Hop limit values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
} else {
|
||||
if pkt[1] != pktTarget[1] {
|
||||
// cannot coalesce with unequal ToS values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if pkt[6]>>5 != pktTarget[6]>>5 {
|
||||
// cannot coalesce with unequal DF or reserved bits. MF is checked
|
||||
// further up the stack.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if pkt[8] != pktTarget[8] {
|
||||
// cannot coalesce with unequal TTL values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
}
|
||||
// seq adjacency
|
||||
lhsLen := item.gsoSize
|
||||
lhsLen += item.numMerged * item.gsoSize
|
||||
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
||||
if item.pshSet {
|
||||
// We cannot append to a segment that has the PSH flag set, PSH
|
||||
// can only be set on the final segment in a reassembled group.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
||||
// A smaller than gsoSize packet has been appended previously.
|
||||
// Nothing can come after a smaller packet on the end.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if gsoSize > item.gsoSize {
|
||||
// We cannot have a larger packet following a smaller one.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
return coalesceAppend
|
||||
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
||||
if pshSet {
|
||||
// We cannot prepend with a segment that has the PSH flag set, PSH
|
||||
// can only be set on the final segment in a reassembled group.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if gsoSize < item.gsoSize {
|
||||
// We cannot have a larger packet following a smaller one.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
||||
// There's at least one previous merge, and we're larger than all
|
||||
// previous. This would put multiple smaller packets on the end.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
return coalescePrepend
|
||||
}
|
||||
return coalesceUnavailable
|
||||
}
|
||||
|
||||
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
|
||||
srcAddrAt := ipv4SrcAddrOffset
|
||||
addrSize := 4
|
||||
if isV6 {
|
||||
srcAddrAt = ipv6SrcAddrOffset
|
||||
addrSize = 16
|
||||
}
|
||||
tcpTotalLen := uint16(len(pkt) - int(iphLen))
|
||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
|
||||
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
|
||||
}
|
||||
|
||||
// coalesceResult represents the result of attempting to coalesce two TCP
|
||||
// packets.
|
||||
type coalesceResult int
|
||||
|
||||
const (
|
||||
coalesceInsufficientCap coalesceResult = 0
|
||||
coalescePSHEnding coalesceResult = 1
|
||||
coalesceItemInvalidCSum coalesceResult = 2
|
||||
coalescePktInvalidCSum coalesceResult = 3
|
||||
coalesceSuccess coalesceResult = 4
|
||||
)
|
||||
|
||||
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
|
||||
// item, returning the outcome. This function may swap bufs elements in the
|
||||
// event of a prepend as item's bufs index is already being tracked for writing
|
||||
// to a Device.
|
||||
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
||||
var pktHead []byte // the packet that will end up at the front
|
||||
headersLen := item.iphLen + item.tcphLen
|
||||
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
||||
|
||||
// Copy data
|
||||
if mode == coalescePrepend {
|
||||
pktHead = pkt
|
||||
if cap(pkt)-bufsOffset < coalescedLen {
|
||||
// We don't want to allocate a new underlying array if capacity is
|
||||
// too small.
|
||||
return coalesceInsufficientCap
|
||||
}
|
||||
if pshSet {
|
||||
return coalescePSHEnding
|
||||
}
|
||||
if item.numMerged == 0 {
|
||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
||||
return coalesceItemInvalidCSum
|
||||
}
|
||||
}
|
||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
||||
return coalescePktInvalidCSum
|
||||
}
|
||||
item.sentSeq = seq
|
||||
extendBy := coalescedLen - len(pktHead)
|
||||
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
|
||||
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
|
||||
// Flip the slice headers in bufs as part of prepend. The index of item
|
||||
// is already being tracked for writing.
|
||||
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
|
||||
} else {
|
||||
pktHead = bufs[item.bufsIndex][bufsOffset:]
|
||||
if cap(pktHead)-bufsOffset < coalescedLen {
|
||||
// We don't want to allocate a new underlying array if capacity is
|
||||
// too small.
|
||||
return coalesceInsufficientCap
|
||||
}
|
||||
if item.numMerged == 0 {
|
||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
||||
return coalesceItemInvalidCSum
|
||||
}
|
||||
}
|
||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
||||
return coalescePktInvalidCSum
|
||||
}
|
||||
if pshSet {
|
||||
// We are appending a segment with PSH set.
|
||||
item.pshSet = pshSet
|
||||
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
|
||||
}
|
||||
extendBy := len(pkt) - int(headersLen)
|
||||
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
||||
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
||||
}
|
||||
|
||||
if gsoSize > item.gsoSize {
|
||||
item.gsoSize = gsoSize
|
||||
}
|
||||
hdr := virtioNetHdr{
|
||||
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
||||
hdrLen: uint16(headersLen),
|
||||
gsoSize: uint16(item.gsoSize),
|
||||
csumStart: uint16(item.iphLen),
|
||||
csumOffset: 16,
|
||||
}
|
||||
|
||||
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
|
||||
// (IPv4) header checksum.
|
||||
if isV6 {
|
||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
||||
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
|
||||
} else {
|
||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
|
||||
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
|
||||
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
|
||||
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
|
||||
}
|
||||
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
|
||||
|
||||
// Calculate the pseudo header checksum and place it at the TCP checksum
|
||||
// offset. Downstream checksum offloading will combine this with computation
|
||||
// of the tcp header and payload checksum.
|
||||
addrLen := 4
|
||||
addrOffset := ipv4SrcAddrOffset
|
||||
if isV6 {
|
||||
addrLen = 16
|
||||
addrOffset = ipv6SrcAddrOffset
|
||||
}
|
||||
srcAddrAt := bufsOffset + addrOffset
|
||||
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
||||
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
||||
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
|
||||
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
|
||||
|
||||
item.numMerged++
|
||||
return coalesceSuccess
|
||||
}
|
||||
|
||||
const (
|
||||
ipv4FlagMoreFragments uint8 = 0x20
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4SrcAddrOffset = 12
|
||||
ipv6SrcAddrOffset = 8
|
||||
maxUint16 = 1<<16 - 1
|
||||
)
|
||||
|
||||
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
|
||||
// existing packets tracked in table. It will return false when pktI is not
|
||||
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
|
||||
// should be written to the Device.
|
||||
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
|
||||
pkt := bufs[pktI][offset:]
|
||||
if len(pkt) > maxUint16 {
|
||||
// A valid IPv4 or IPv6 packet will never exceed this.
|
||||
return false
|
||||
}
|
||||
iphLen := int((pkt[0] & 0x0F) * 4)
|
||||
if isV6 {
|
||||
iphLen = 40
|
||||
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
||||
if ipv6HPayloadLen != len(pkt)-iphLen {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
||||
if totalLen != len(pkt) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(pkt) < iphLen {
|
||||
return false
|
||||
}
|
||||
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
|
||||
if tcphLen < 20 || tcphLen > 60 {
|
||||
return false
|
||||
}
|
||||
if len(pkt) < iphLen+tcphLen {
|
||||
return false
|
||||
}
|
||||
if !isV6 {
|
||||
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
||||
// no GRO support for fragmented segments for now
|
||||
return false
|
||||
}
|
||||
}
|
||||
tcpFlags := pkt[iphLen+tcpFlagsOffset]
|
||||
var pshSet bool
|
||||
// not a candidate if any non-ACK flags (except PSH+ACK) are set
|
||||
if tcpFlags != tcpFlagACK {
|
||||
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
|
||||
return false
|
||||
}
|
||||
pshSet = true
|
||||
}
|
||||
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
|
||||
// not a candidate if payload len is 0
|
||||
if gsoSize < 1 {
|
||||
return false
|
||||
}
|
||||
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
|
||||
srcAddrOffset := ipv4SrcAddrOffset
|
||||
addrLen := 4
|
||||
if isV6 {
|
||||
srcAddrOffset = ipv6SrcAddrOffset
|
||||
addrLen = 16
|
||||
}
|
||||
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||
if !existing {
|
||||
return false
|
||||
}
|
||||
for i := len(items) - 1; i >= 0; i-- {
|
||||
// In the best case of packets arriving in order iterating in reverse is
|
||||
// more efficient if there are multiple items for a given flow. This
|
||||
// also enables a natural table.deleteAt() in the
|
||||
// coalesceItemInvalidCSum case without the need for index tracking.
|
||||
// This algorithm makes a best effort to coalesce in the event of
|
||||
// unordered packets, where pkt may land anywhere in items from a
|
||||
// sequence number perspective, however once an item is inserted into
|
||||
// the table it is never compared across other items later.
|
||||
item := items[i]
|
||||
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
|
||||
if can != coalesceUnavailable {
|
||||
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
|
||||
switch result {
|
||||
case coalesceSuccess:
|
||||
table.updateAt(item, i)
|
||||
return true
|
||||
case coalesceItemInvalidCSum:
|
||||
// delete the item with an invalid csum
|
||||
table.deleteAt(item.key, i)
|
||||
case coalescePktInvalidCSum:
|
||||
// no point in inserting an item that we can't coalesce
|
||||
return false
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
// failed to coalesce with any other packets; store the item in the flow
|
||||
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||
return false
|
||||
}
|
||||
|
||||
func isTCP4NoIPOptions(b []byte) bool {
|
||||
if len(b) < 40 {
|
||||
return false
|
||||
}
|
||||
if b[0]>>4 != 4 {
|
||||
return false
|
||||
}
|
||||
if b[0]&0x0F != 5 {
|
||||
return false
|
||||
}
|
||||
if b[9] != unix.IPPROTO_TCP {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isTCP6NoEH(b []byte) bool {
|
||||
if len(b) < 60 {
|
||||
return false
|
||||
}
|
||||
if b[0]>>4 != 6 {
|
||||
return false
|
||||
}
|
||||
if b[6] != unix.IPPROTO_TCP {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
||||
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
|
||||
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
||||
// and recycle them across vectors of packets.
|
||||
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
|
||||
for i := range bufs {
|
||||
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
||||
return errors.New("invalid offset")
|
||||
}
|
||||
var coalesced bool
|
||||
switch {
|
||||
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
|
||||
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
|
||||
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
|
||||
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
|
||||
}
|
||||
if !coalesced {
|
||||
hdr := virtioNetHdr{}
|
||||
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*toWrite = append(*toWrite, i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tcpTSO splits packets from in into outBuffs, writing the size of each
|
||||
// element into sizes. It returns the number of buffers populated, and/or an
|
||||
// error.
|
||||
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
|
||||
iphLen := int(hdr.csumStart)
|
||||
srcAddrOffset := ipv6SrcAddrOffset
|
||||
addrLen := 16
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||
in[10], in[11] = 0, 0 // clear ipv4 header checksum
|
||||
srcAddrOffset = ipv4SrcAddrOffset
|
||||
addrLen = 4
|
||||
}
|
||||
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
|
||||
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
|
||||
nextSegmentDataAt := int(hdr.hdrLen)
|
||||
i := 0
|
||||
for ; nextSegmentDataAt < len(in); i++ {
|
||||
if i == len(outBuffs) {
|
||||
return i - 1, ErrTooManySegments
|
||||
}
|
||||
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
|
||||
if nextSegmentEnd > len(in) {
|
||||
nextSegmentEnd = len(in)
|
||||
}
|
||||
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
||||
totalLen := int(hdr.hdrLen) + segmentDataLen
|
||||
sizes[i] = totalLen
|
||||
out := outBuffs[i][outOffset:]
|
||||
|
||||
copy(out, in[:iphLen])
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||
// For IPv4 we are responsible for incrementing the ID field,
|
||||
// updating the total len field, and recalculating the header
|
||||
// checksum.
|
||||
if i > 0 {
|
||||
id := binary.BigEndian.Uint16(out[4:])
|
||||
id += uint16(i)
|
||||
binary.BigEndian.PutUint16(out[4:], id)
|
||||
}
|
||||
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
||||
ipv4CSum := ^checksum(out[:iphLen], 0)
|
||||
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
||||
} else {
|
||||
// For IPv6 we are responsible for updating the payload length field.
|
||||
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
||||
}
|
||||
|
||||
// TCP header
|
||||
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
|
||||
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
|
||||
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
|
||||
if nextSegmentEnd != len(in) {
|
||||
// FIN and PSH should only be set on last segment
|
||||
clearFlags := tcpFlagFIN | tcpFlagPSH
|
||||
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
|
||||
}
|
||||
|
||||
// payload
|
||||
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
||||
|
||||
// TCP checksum
|
||||
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
|
||||
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
|
||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
|
||||
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
|
||||
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
|
||||
|
||||
nextSegmentDataAt += int(hdr.gsoSize)
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
||||
cSumAt := cSumStart + cSumOffset
|
||||
// The initial value at the checksum offset should be summed with the
|
||||
// checksum we compute. This is typically the pseudo-header checksum.
|
||||
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
||||
in[cSumAt], in[cSumAt+1] = 0, 0
|
||||
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
|
||||
return nil
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
type Event int
|
||||
|
||||
const (
|
||||
EventUp = 1 << iota
|
||||
EventDown
|
||||
EventMTUUpdate
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
// File returns the file descriptor of the device.
|
||||
File() *os.File
|
||||
|
||||
// Read one or more packets from the Device (without any additional headers).
|
||||
// On a successful read it returns the number of packets read, and sets
|
||||
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
|
||||
// A nonzero offset can be used to instruct the Device on where to begin
|
||||
// reading into each element of the bufs slice.
|
||||
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
|
||||
|
||||
// Write one or more packets to the device (without any additional headers).
|
||||
// On a successful write it returns the number of packets written. A nonzero
|
||||
// offset can be used to instruct the Device on where to begin writing from
|
||||
// each packet contained within the bufs slice.
|
||||
Write(bufs [][]byte, offset int) (int, error)
|
||||
|
||||
// MTU returns the MTU of the Device.
|
||||
MTU() (int, error)
|
||||
|
||||
// Name returns the current name of the Device.
|
||||
Name() (string, error)
|
||||
|
||||
// Events returns a channel of type Event, which is fed Device events.
|
||||
Events() <-chan Event
|
||||
|
||||
// Close stops the Device and closes the Event channel.
|
||||
Close() error
|
||||
|
||||
// BatchSize returns the preferred/max number of packets that can be read or
|
||||
// written in a single read/write call. BatchSize must not change over the
|
||||
// lifetime of a Device.
|
||||
BatchSize() int
|
||||
}
|
||||
@@ -1,664 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package tun
|
||||
|
||||
/* Implementation of the TUN device interface for linux
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
const (
|
||||
cloneDevicePath = "/dev/net/tun"
|
||||
ifReqSize = unix.IFNAMSIZ + 64
|
||||
)
|
||||
|
||||
type NativeTun struct {
|
||||
tunFile *os.File
|
||||
index int32 // if index
|
||||
errors chan error // async error handling
|
||||
events chan Event // device related events
|
||||
netlinkSock int
|
||||
netlinkCancel *rwcancel.RWCancel
|
||||
hackListenerClosed sync.Mutex
|
||||
statusListenersShutdown chan struct{}
|
||||
batchSize int
|
||||
vnetHdr bool
|
||||
|
||||
closeOnce sync.Once
|
||||
|
||||
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
||||
nameCache string // name of interface
|
||||
nameErr error
|
||||
|
||||
readOpMu sync.Mutex // readOpMu guards readBuff
|
||||
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
||||
|
||||
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
|
||||
toWrite []int
|
||||
tcp4GROTable, tcp6GROTable *tcpGROTable
|
||||
}
|
||||
|
||||
func (tun *NativeTun) File() *os.File {
|
||||
return tun.tunFile
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineHackListener() {
|
||||
defer tun.hackListenerClosed.Unlock()
|
||||
/* This is needed for the detection to work across network namespaces
|
||||
* If you are reading this and know a better method, please get in touch.
|
||||
*/
|
||||
last := 0
|
||||
const (
|
||||
up = 1
|
||||
down = 2
|
||||
)
|
||||
for {
|
||||
sysconn, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err2 := sysconn.Control(func(fd uintptr) {
|
||||
_, err = unix.Write(int(fd), nil)
|
||||
})
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
switch err {
|
||||
case unix.EINVAL:
|
||||
if last != up {
|
||||
// If the tunnel is up, it reports that write() is
|
||||
// allowed but we provided invalid data.
|
||||
tun.events <- EventUp
|
||||
last = up
|
||||
}
|
||||
case unix.EIO:
|
||||
if last != down {
|
||||
// If the tunnel is down, it reports that no I/O
|
||||
// is possible, without checking our provided data.
|
||||
tun.events <- EventDown
|
||||
last = down
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
// nothing
|
||||
case <-tun.statusListenersShutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createNetlinkSocket() (int, error) {
|
||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
saddr := &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
|
||||
}
|
||||
err = unix.Bind(sock, saddr)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return sock, nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineNetlinkListener() {
|
||||
defer func() {
|
||||
unix.Close(tun.netlinkSock)
|
||||
tun.hackListenerClosed.Lock()
|
||||
close(tun.events)
|
||||
tun.netlinkCancel.Close()
|
||||
}()
|
||||
|
||||
for msg := make([]byte, 1<<16); ; {
|
||||
var err error
|
||||
var msgn int
|
||||
for {
|
||||
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||
break
|
||||
}
|
||||
if !tun.netlinkCancel.ReadyRead() {
|
||||
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-tun.statusListenersShutdown:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
wasEverUp := false
|
||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||
|
||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||
|
||||
if int(hdr.Len) > len(remain) {
|
||||
break
|
||||
}
|
||||
|
||||
switch hdr.Type {
|
||||
case unix.NLMSG_DONE:
|
||||
remain = []byte{}
|
||||
|
||||
case unix.RTM_NEWLINK:
|
||||
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
||||
remain = remain[hdr.Len:]
|
||||
|
||||
if info.Index != tun.index {
|
||||
// not our interface
|
||||
continue
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||
tun.events <- EventUp
|
||||
wasEverUp = true
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||
// Don't emit EventDown before we've ever emitted EventUp.
|
||||
// This avoids a startup race with HackListener, which
|
||||
// might detect Up before we have finished reporting Down.
|
||||
if wasEverUp {
|
||||
tun.events <- EventDown
|
||||
}
|
||||
}
|
||||
|
||||
tun.events <- EventMTUUpdate
|
||||
|
||||
default:
|
||||
remain = remain[hdr.Len:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getIFIndex(name string) (int32, error) {
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer unix.Close(fd)
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], name)
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCGIFINDEX),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) setMTU(n int) error {
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open datagram socket
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], name)
|
||||
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCSIFMTU),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineNetlinkRead() {
|
||||
defer func() {
|
||||
unix.Close(tun.netlinkSock)
|
||||
tun.hackListenerClosed.Lock()
|
||||
close(tun.events)
|
||||
tun.netlinkCancel.Close()
|
||||
}()
|
||||
|
||||
for msg := make([]byte, 1<<16); ; {
|
||||
var err error
|
||||
var msgn int
|
||||
for {
|
||||
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||
break
|
||||
}
|
||||
if !tun.netlinkCancel.ReadyRead() {
|
||||
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
wasEverUp := false
|
||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||
|
||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||
|
||||
if int(hdr.Len) > len(remain) {
|
||||
break
|
||||
}
|
||||
|
||||
switch hdr.Type {
|
||||
case unix.NLMSG_DONE:
|
||||
remain = []byte{}
|
||||
|
||||
case unix.RTM_NEWLINK:
|
||||
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
||||
remain = remain[hdr.Len:]
|
||||
|
||||
if info.Index != tun.index {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||
tun.events <- EventUp
|
||||
wasEverUp = true
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||
if wasEverUp {
|
||||
tun.events <- EventDown
|
||||
}
|
||||
}
|
||||
tun.events <- EventMTUUpdate
|
||||
|
||||
default:
|
||||
remain = remain[hdr.Len:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineNetlink() {
|
||||
var err error
|
||||
|
||||
tun.netlinkSock, err = createNetlinkSocket()
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
go tun.routineNetlinkListener()
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
var err1, err2 error
|
||||
tun.closeOnce.Do(func() {
|
||||
if tun.statusListenersShutdown != nil {
|
||||
close(tun.statusListenersShutdown)
|
||||
if tun.netlinkCancel != nil {
|
||||
err1 = tun.netlinkCancel.Cancel()
|
||||
}
|
||||
} else if tun.events != nil {
|
||||
close(tun.events)
|
||||
}
|
||||
err2 = tun.tunFile.Close()
|
||||
})
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return tun.batchSize
|
||||
}
|
||||
|
||||
const (
|
||||
// TODO: support TSO with ECN bits
|
||||
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||
)
|
||||
|
||||
func (tun *NativeTun) initFromFlags(name string) error {
|
||||
sc, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if e := sc.Control(func(fd uintptr) {
|
||||
var (
|
||||
ifr *unix.Ifreq
|
||||
)
|
||||
ifr, err = unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
got := ifr.Uint16()
|
||||
if got&unix.IFF_VNET_HDR != 0 {
|
||||
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tun.vnetHdr = true
|
||||
tun.batchSize = wgconn.IdealBatchSize
|
||||
} else {
|
||||
tun.batchSize = 1
|
||||
}
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateTUN creates a Device with the provided name and MTU.
|
||||
func CreateTUN(name string, mtu int) (Device, error) {
|
||||
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
|
||||
}
|
||||
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
|
||||
tun, err := CreateTUNFromFile(fd, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if name != "tun" {
|
||||
if err := tun.(*NativeTun).initFromFlags(name); err != nil {
|
||||
tun.Close()
|
||||
return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
|
||||
}
|
||||
}
|
||||
return tun, nil
|
||||
}
|
||||
|
||||
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
|
||||
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
tun := &NativeTun{
|
||||
tunFile: file,
|
||||
errors: make(chan error, 5),
|
||||
events: make(chan Event, 5),
|
||||
}
|
||||
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine TUN name: %w", err)
|
||||
}
|
||||
|
||||
if err := tun.initFromFlags(name); err != nil {
|
||||
return nil, fmt.Errorf("failed to query TUN flags: %w", err)
|
||||
}
|
||||
|
||||
if tun.batchSize == 0 {
|
||||
tun.batchSize = 1
|
||||
}
|
||||
|
||||
tun.index, err = getIFIndex(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get TUN index: %w", err)
|
||||
}
|
||||
|
||||
if err = tun.setMTU(mtu); err != nil {
|
||||
return nil, fmt.Errorf("failed to set MTU: %w", err)
|
||||
}
|
||||
|
||||
tun.statusListenersShutdown = make(chan struct{})
|
||||
go tun.routineNetlink()
|
||||
|
||||
if tun.batchSize == 0 {
|
||||
tun.batchSize = 1
|
||||
}
|
||||
|
||||
tun.tcp4GROTable = newTCPGROTable()
|
||||
tun.tcp6GROTable = newTCPGROTable()
|
||||
|
||||
return tun, nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Name() (string, error) {
|
||||
tun.nameOnce.Do(tun.initNameCache)
|
||||
return tun.nameCache, tun.nameErr
|
||||
}
|
||||
|
||||
func (tun *NativeTun) initNameCache() {
|
||||
sysconn, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
tun.nameErr = err
|
||||
return
|
||||
}
|
||||
err = sysconn.Control(func(fd uintptr) {
|
||||
var ifr [ifReqSize]byte
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
fd,
|
||||
uintptr(unix.TUNGETIFF),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
if errno != 0 {
|
||||
tun.nameErr = errno
|
||||
return
|
||||
}
|
||||
tun.nameCache = unix.ByteSliceToString(ifr[:])
|
||||
})
|
||||
if err != nil && tun.nameErr == nil {
|
||||
tun.nameErr = err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) MTU() (int, error) {
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// open datagram socket
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], name)
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCGIFMTU),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||
tun.writeOpMu.Lock()
|
||||
defer func() {
|
||||
tun.tcp4GROTable.reset()
|
||||
tun.tcp6GROTable.reset()
|
||||
tun.writeOpMu.Unlock()
|
||||
}()
|
||||
var (
|
||||
errs error
|
||||
total int
|
||||
)
|
||||
tun.toWrite = tun.toWrite[:0]
|
||||
if tun.vnetHdr {
|
||||
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
offset -= virtioNetHdrLen
|
||||
} else {
|
||||
for i := range bufs {
|
||||
tun.toWrite = append(tun.toWrite, i)
|
||||
}
|
||||
}
|
||||
for _, bufsI := range tun.toWrite {
|
||||
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
return total, os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
errs = errors.Join(errs, err)
|
||||
} else {
|
||||
total += n
|
||||
}
|
||||
}
|
||||
return total, errs
|
||||
}
|
||||
|
||||
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
|
||||
// each buffer. It mutates sizes to reflect the size of each element of bufs,
|
||||
// and returns the number of packets read.
|
||||
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
var hdr virtioNetHdr
|
||||
if err := hdr.decode(in); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
in = in[virtioNetHdrLen:]
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
||||
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||
if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if len(in) > len(bufs[0][offset:]) {
|
||||
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
|
||||
}
|
||||
n := copy(bufs[0][offset:], in)
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
||||
}
|
||||
|
||||
ipVersion := in[0] >> 4
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
case 6:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
||||
}
|
||||
|
||||
if len(in) <= int(hdr.csumStart+12) {
|
||||
return 0, errors.New("packet is too short")
|
||||
}
|
||||
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
||||
if tcpHLen < 20 || tcpHLen > 60 {
|
||||
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||
}
|
||||
hdr.hdrLen = hdr.csumStart + tcpHLen
|
||||
if len(in) < int(hdr.hdrLen) {
|
||||
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
||||
}
|
||||
if hdr.hdrLen < hdr.csumStart {
|
||||
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
|
||||
}
|
||||
cSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||
if cSumAt+1 >= len(in) {
|
||||
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
||||
}
|
||||
|
||||
return tcpTSO(in, hdr, bufs, sizes, offset)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
tun.readOpMu.Lock()
|
||||
defer tun.readOpMu.Unlock()
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
readInto := bufs[0][offset:]
|
||||
if tun.vnetHdr {
|
||||
readInto = tun.readBuff[:]
|
||||
}
|
||||
n, err := tun.tunFile.Read(readInto)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tun.vnetHdr {
|
||||
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user