Compare commits


11 Commits

Author      SHA1        Message                           Date
Ryan        2d128a3254  add locking for stop crash        2025-11-05 11:58:25 -05:00
Ryan        c8980d34cf  fixes                             2025-11-05 10:54:08 -05:00
Ryan        98f264cf14  works well                        2025-11-04 19:33:52 -05:00
Ryan        aa44f4c7c9  hmmmmmm it works i guess maybe    2025-11-04 16:08:31 -05:00
Ryan Huber  419157c407  passes traffic                    2025-11-04 04:50:35 +00:00
Ryan Huber  0864852d33  updated bind                      2025-11-04 04:39:07 +00:00
Ryan Huber  2b5aec9a18  updated udp                       2025-11-04 04:34:59 +00:00
Ryan Huber  f0665bee20  pem.go restored                   2025-11-04 04:32:22 +00:00
Ryan Huber  11da0baab1  quick fix                         2025-11-04 04:21:27 +00:00
Ryan Huber  608904b9dd  add new files for compat layer    2025-11-04 04:10:51 +00:00
Ryan Huber  fd1c52127f  first try                         2025-11-04 04:00:29 +00:00
82 changed files with 4920 additions and 1397 deletions

View File

@@ -14,7 +14,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:

View File

@@ -10,7 +10,7 @@ jobs:
    name: Build Linux/BSD All
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:
@@ -24,7 +24,7 @@ jobs:
          mv build/*.tar.gz release
      - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
        with:
          name: linux-latest
          path: release
@@ -33,7 +33,7 @@ jobs:
    name: Build Windows
    runs-on: windows-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:
@@ -55,7 +55,7 @@ jobs:
          mv dist\windows\wintun build\dist\windows\
      - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
        with:
          name: windows-latest
          path: build
@@ -66,7 +66,7 @@ jobs:
      HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
    runs-on: macos-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:
@@ -104,7 +104,7 @@ jobs:
          fi
      - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
        with:
          name: darwin-latest
          path: ./release/*
@@ -124,11 +124,11 @@ jobs:
      # be overwritten
      - name: Checkout code
        if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/checkout@v5
+        uses: actions/checkout@v4
      - name: Download artifacts
        if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/download-artifact@v6
+        uses: actions/download-artifact@v4
        with:
          name: linux-latest
          path: artifacts
@@ -160,10 +160,10 @@ jobs:
    needs: [build-linux, build-darwin, build-windows]
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - name: Download artifacts
-        uses: actions/download-artifact@v6
+        uses: actions/download-artifact@v4
        with:
          path: artifacts

View File

@@ -20,7 +20,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:

View File

@@ -18,7 +18,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:

View File

@@ -18,7 +18,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:
@@ -32,7 +32,7 @@ jobs:
        run: make vet
      - name: golangci-lint
-        uses: golangci/golangci-lint-action@v9
+        uses: golangci/golangci-lint-action@v8
        with:
          version: v2.5
@@ -45,7 +45,7 @@ jobs:
      - name: Build test mobile
        run: make build-test-mobile
-      - uses: actions/upload-artifact@v5
+      - uses: actions/upload-artifact@v4
        with:
          name: e2e packet flow linux-latest
          path: e2e/mermaid/linux-latest
@@ -56,7 +56,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:
@@ -77,7 +77,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:
@@ -98,7 +98,7 @@ jobs:
        os: [windows-latest, macos-latest]
    steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
        with:
@@ -115,7 +115,7 @@ jobs:
        run: make vet
      - name: golangci-lint
-        uses: golangci/golangci-lint-action@v9
+        uses: golangci/golangci-lint-action@v8
        with:
          version: v2.5
@@ -125,7 +125,7 @@ jobs:
      - name: End 2 end
        run: make e2evv
-      - uses: actions/upload-artifact@v5
+      - uses: actions/upload-artifact@v4
        with:
          name: e2e packet flow ${{ matrix.os }}
          path: e2e/mermaid/${{ matrix.os }}

batch_pipeline.go (new file, 164 lines)
View File

@@ -0,0 +1,164 @@
package nebula
import (
"net/netip"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
// batchPipelines tracks whether the inside device can operate on packet batches
// and, if so, holds the shared packet pool sized for the virtio headroom and
// payload limits advertised by the device. It also owns the fan-in/fan-out
// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
type batchPipelines struct {
enabled bool
inside overlay.BatchCapableDevice
headroom int
payloadCap int
pool *overlay.PacketPool
batchSize int
routines int
rxQueues []chan *overlay.Packet
txQueues []chan queuedDatagram
tunQueues []chan *overlay.Packet
}
type queuedDatagram struct {
packet *overlay.Packet
addr netip.AddrPort
}
func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
if device == nil || routines <= 0 {
return
}
bcap, ok := device.(overlay.BatchCapableDevice)
if !ok {
return
}
headroom := bcap.BatchHeadroom()
payload := bcap.BatchPayloadCap()
if maxSegments < 1 {
maxSegments = 1
}
requiredPayload := udp.MTU * maxSegments
if payload < requiredPayload {
payload = requiredPayload
}
batchSize := bcap.BatchSize()
if headroom <= 0 || payload <= 0 || batchSize <= 0 {
return
}
bp.enabled = true
bp.inside = bcap
bp.headroom = headroom
bp.payloadCap = payload
bp.batchSize = batchSize
bp.routines = routines
bp.pool = overlay.NewPacketPool(headroom, payload)
queueCap := batchSize * defaultBatchQueueDepthFactor
if queueDepth > 0 {
queueCap = queueDepth
}
if queueCap < batchSize {
queueCap = batchSize
}
bp.rxQueues = make([]chan *overlay.Packet, routines)
bp.txQueues = make([]chan queuedDatagram, routines)
bp.tunQueues = make([]chan *overlay.Packet, routines)
for i := 0; i < routines; i++ {
bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
bp.txQueues[i] = make(chan queuedDatagram, queueCap)
bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
}
}
func (bp *batchPipelines) Pool() *overlay.PacketPool {
if bp == nil || !bp.enabled {
return nil
}
return bp.pool
}
func (bp *batchPipelines) Enabled() bool {
return bp != nil && bp.enabled
}
func (bp *batchPipelines) batchSizeHint() int {
if bp == nil || bp.batchSize <= 0 {
return 1
}
return bp.batchSize
}
func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
return nil
}
return bp.rxQueues[i]
}
func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
return nil
}
return bp.txQueues[i]
}
func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
return nil
}
return bp.tunQueues[i]
}
func (bp *batchPipelines) txQueueLen(i int) int {
q := bp.txQueue(i)
if q == nil {
return 0
}
return len(q)
}
func (bp *batchPipelines) tunQueueLen(i int) int {
q := bp.tunQueue(i)
if q == nil {
return 0
}
return len(q)
}
func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
q := bp.rxQueue(i)
if q == nil {
return false
}
q <- pkt
return true
}
func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
q := bp.txQueue(i)
if q == nil {
return false
}
q <- queuedDatagram{packet: pkt, addr: addr}
return true
}
func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
q := bp.tunQueue(i)
if q == nil {
return false
}
q <- pkt
return true
}
func (bp *batchPipelines) newPacket() *overlay.Packet {
if bp == nil || !bp.enabled || bp.pool == nil {
return nil
}
return bp.pool.Get()
}
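The new file above wires per-routine buffered channels between the TUN readers, the encrypt/decrypt workers, and the UDP writers. A minimal, self-contained sketch of the consuming side follows; the types and the txWorker/flush names are stand-ins invented for illustration (the real code uses overlay.Packet, queuedDatagram, and the UDP writer), so treat this as a rough model of how a worker might drain one txQueue in batches, not as the project's implementation.

package main

import (
    "fmt"
    "net/netip"
)

type packet struct{ payload []byte }

type datagram struct {
    pkt  *packet
    addr netip.AddrPort
}

// txWorker drains one per-routine queue, flushing in groups of up to batchSize.
func txWorker(q <-chan datagram, batchSize int, flush func([]datagram)) {
    batch := make([]datagram, 0, batchSize)
    for d := range q {
        batch = append(batch, d)
    drain:
        // Opportunistically pull whatever else is already queued so one flush
        // (one batched send in the real code) covers several datagrams.
        for len(batch) < batchSize {
            select {
            case more, ok := <-q:
                if !ok {
                    break drain
                }
                batch = append(batch, more)
            default:
                break drain
            }
        }
        flush(batch)
        batch = batch[:0]
    }
}

func main() {
    q := make(chan datagram, 8)
    addr := netip.MustParseAddrPort("192.0.2.1:4242")
    for i := 0; i < 5; i++ {
        q <- datagram{pkt: &packet{payload: []byte{byte(i)}}, addr: addr}
    }
    close(q)
    txWorker(q, 4, func(b []datagram) { fmt.Println("flushing", len(b), "datagrams") })
}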

bits.go
View File

@@ -9,13 +9,14 @@ type Bits struct {
    length             uint64
    current            uint64
    bits               []bool
+   firstSeen          bool
    lostCounter        metrics.Counter
    dupeCounter        metrics.Counter
    outOfWindowCounter metrics.Counter
}

func NewBits(bits uint64) *Bits {
-   b := &Bits{
+   return &Bits{
        length:  bits,
        bits:    make([]bool, bits, bits),
        current: 0,
@@ -23,37 +24,34 @@ func NewBits(bits uint64) *Bits {
        dupeCounter:        metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
        outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
    }
-   // There is no counter value 0, mark it to avoid counting a lost packet later.
-   b.bits[0] = true
-   b.current = 0
-   return b
}

-func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
+func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
    // If i is the next number, return true.
-   if i > b.current {
+   if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
        return true
    }

-   // If i is within the window, check if it's been set already.
-   if i > b.current-b.length || i < b.length && b.current < b.length {
-       return !b.bits[i%b.length]
-   }
-
-   // If i is within the first window
-   if i < b.length {
+   // If i is within the window, check if it's been set already. The first window will fail this check
+   if i > b.current-b.length {
        return !b.bits[i%b.length]
    }

    // Not within the window
-   if l.Level >= logrus.DebugLevel {
-       l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
-   }
+   l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
    return false
}

func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
    // If i is the next number, return true and update current.
    if i == b.current+1 {
-       // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
-       // The very first window can only be tracked as lost once we are on the 2nd window or greater
-       if b.bits[i%b.length] == false && i > b.length {
+       // Report missed packets, we can only understand what was missed after the first window has been gone through
+       if i > b.length && b.bits[i%b.length] == false {
            b.lostCounter.Inc(1)
        }
        b.bits[i%b.length] = true
@@ -61,32 +59,61 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
        return true
    }

-   // If i is a jump, adjust the window, record lost, update current, and return true
-   if i > b.current {
-       lost := int64(0)
-       // Zero out the bits between the current and the new counter value, limited by the window size,
-       // since the window is shifting
-       for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
-           if b.bits[n%b.length] == false && n > b.length {
-               lost++
-           }
+   // If i packet is greater than current but less than the maximum length of our bitmap,
+   // flip everything in between to false and move ahead.
+   if i > b.current && i < b.current+b.length {
+       // In between current and i need to be zero'd to allow those packets to come in later
+       for n := b.current + 1; n < i; n++ {
            b.bits[n%b.length] = false
        }
-       // Only record any skipped packets as a result of the window moving further than the window length
-       // Any loss within the new window will be accounted for in future calls
-       lost += max(0, int64(i-b.current-b.length))
+       b.bits[i%b.length] = true
+       b.current = i
+       //l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
+       return true
+   }
+
+   // If i is greater than the delta between current and the total length of our bitmap,
+   // just flip everything in the map and move ahead.
+   if i >= b.current+b.length {
+       // The current window loss will be accounted for later, only record the jump as loss up until then
+       lost := maxInt64(0, int64(i-b.current-b.length))
+       //TODO: explain this
+       if b.current == 0 {
+           lost++
+       }
+       for n := range b.bits {
+           // Don't want to count the first window as a loss
+           //TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
+           //if b.bits[n] == false {
+           //   lost++
+           //}
+           b.bits[n] = false
+       }
        b.lostCounter.Inc(lost)
-       if l.Level >= logrus.DebugLevel {
-           l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
-               Debug("Receive window")
-       }
        b.bits[i%b.length] = true
        b.current = i
        return true
    }

-   // If i is within the current window but below the current counter,
-   // Check to see if it's a duplicate
-   if i > b.current-b.length || i < b.length && b.current < b.length {
-       if b.current == i || b.bits[i%b.length] == true {
+   // Allow for the 0 packet to come in within the first window
+   if i == 0 && b.firstSeen == false && b.current < b.length {
+       b.firstSeen = true
+       b.bits[i%b.length] = true
+       return true
+   }
+
+   // If i is within the window of current minus length (the total pat window size),
+   // allow it and flip to true but to NOT change current. We also have to account for the first window
+   if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
+       if b.current == i {
            if l.Level >= logrus.DebugLevel {
                l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
                    Debug("Receive window")
@@ -95,8 +122,18 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
            return false
        }

+       if b.bits[i%b.length] == true {
+           if l.Level >= logrus.DebugLevel {
+               l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
+                   Debug("Receive window")
+           }
+           b.dupeCounter.Inc(1)
+           return false
+       }
+
        b.bits[i%b.length] = true
        return true
    }

    // In all other cases, fail and don't change current.
@@ -110,3 +147,11 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
    }
    return false
}

+func maxInt64(a, b int64) int64 {
+   if a > b {
+       return a
+   }
+   return b
+}
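A quick way to read the restored window logic: with a 10-slot window, once the counter has advanced to 15 only counters 6 through 15 can still be tracked. The fragment below is a sketch in the style of bits_test.go (it assumes the nebula package context with a logger l and fmt available); the expectations in the comments mirror the existing TestBits values rather than anything new.

b := NewBits(10)
b.Update(l, 15) // jump ahead: the window is cleared and the skipped range is recorded as loss
fmt.Println(b.Check(l, 14), b.Update(l, 14)) // true true  — 14 is still inside the window
fmt.Println(b.Check(l, 14))                  // false      — a second 14 is a duplicate
fmt.Println(b.Check(l, 5))                   // false      — 5 has fallen out of the window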

View File

@@ -15,41 +15,48 @@ func TestBits(t *testing.T) {
    assert.Len(t, b.bits, 10)

    // This is initialized to zero - receive one. This should work.
    assert.True(t, b.Check(l, 1))
-   assert.True(t, b.Update(l, 1))
+   u := b.Update(l, 1)
+   assert.True(t, u)
    assert.EqualValues(t, 1, b.current)
-   g := []bool{true, true, false, false, false, false, false, false, false, false}
+   g := []bool{false, true, false, false, false, false, false, false, false, false}
    assert.Equal(t, g, b.bits)

    // Receive two
    assert.True(t, b.Check(l, 2))
-   assert.True(t, b.Update(l, 2))
+   u = b.Update(l, 2)
+   assert.True(t, u)
    assert.EqualValues(t, 2, b.current)
-   g = []bool{true, true, true, false, false, false, false, false, false, false}
+   g = []bool{false, true, true, false, false, false, false, false, false, false}
    assert.Equal(t, g, b.bits)

    // Receive two again - it will fail
    assert.False(t, b.Check(l, 2))
-   assert.False(t, b.Update(l, 2))
+   u = b.Update(l, 2)
+   assert.False(t, u)
    assert.EqualValues(t, 2, b.current)

    // Jump ahead to 15, which should clear everything and set the 6th element
    assert.True(t, b.Check(l, 15))
-   assert.True(t, b.Update(l, 15))
+   u = b.Update(l, 15)
+   assert.True(t, u)
    assert.EqualValues(t, 15, b.current)
    g = []bool{false, false, false, false, false, true, false, false, false, false}
    assert.Equal(t, g, b.bits)

    // Mark 14, which is allowed because it is in the window
    assert.True(t, b.Check(l, 14))
-   assert.True(t, b.Update(l, 14))
+   u = b.Update(l, 14)
+   assert.True(t, u)
    assert.EqualValues(t, 15, b.current)
    g = []bool{false, false, false, false, true, true, false, false, false, false}
    assert.Equal(t, g, b.bits)

    // Mark 5, which is not allowed because it is not in the window
    assert.False(t, b.Check(l, 5))
-   assert.False(t, b.Update(l, 5))
+   u = b.Update(l, 5)
+   assert.False(t, u)
    assert.EqualValues(t, 15, b.current)
    g = []bool{false, false, false, false, true, true, false, false, false, false}
    assert.Equal(t, g, b.bits)
@@ -62,29 +69,10 @@ func TestBits(t *testing.T) {
    // Walk through a few windows in order
    b = NewBits(10)
-   for i := uint64(1); i <= 100; i++ {
+   for i := uint64(0); i <= 100; i++ {
        assert.True(t, b.Check(l, i), "Error while checking %v", i)
        assert.True(t, b.Update(l, i), "Error while updating %v", i)
    }
-   assert.False(t, b.Check(l, 1), "Out of window check")
-}
-
-func TestBitsLargeJumps(t *testing.T) {
-   l := test.NewLogger()
-   b := NewBits(10)
-   b.lostCounter.Clear()
-
-   b = NewBits(10)
-   b.lostCounter.Clear()
-   assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
-   assert.Equal(t, int64(45), b.lostCounter.Count())
-   assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
-   assert.Equal(t, int64(89), b.lostCounter.Count())
-   assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
-   assert.Equal(t, int64(188), b.lostCounter.Count())
}

func TestBitsDupeCounter(t *testing.T) {
@@ -136,7 +124,8 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
    assert.False(t, b.Update(l, 0))
    assert.Equal(t, int64(1), b.outOfWindowCounter.Count())

-   assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
+   //tODO: make sure lostcounter doesn't increase in orderly increment
+   assert.Equal(t, int64(20), b.lostCounter.Count())
    assert.Equal(t, int64(0), b.dupeCounter.Count())
    assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
@@ -148,6 +137,8 @@ func TestBitsLostCounter(t *testing.T) {
    b.dupeCounter.Clear()
    b.outOfWindowCounter.Clear()

+   //assert.True(t, b.Update(0))
+   assert.True(t, b.Update(l, 0))
    assert.True(t, b.Update(l, 20))
    assert.True(t, b.Update(l, 21))
    assert.True(t, b.Update(l, 22))
@@ -158,7 +149,7 @@ func TestBitsLostCounter(t *testing.T) {
    assert.True(t, b.Update(l, 27))
    assert.True(t, b.Update(l, 28))
    assert.True(t, b.Update(l, 29))
-   assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
+   assert.Equal(t, int64(20), b.lostCounter.Count())
    assert.Equal(t, int64(0), b.dupeCounter.Count())
    assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -167,6 +158,8 @@ func TestBitsLostCounter(t *testing.T) {
    b.dupeCounter.Clear()
    b.outOfWindowCounter.Clear()

+   assert.True(t, b.Update(l, 0))
+   assert.Equal(t, int64(0), b.lostCounter.Count())
    assert.True(t, b.Update(l, 9))
    assert.Equal(t, int64(0), b.lostCounter.Count())
    // 10 will set 0 index, 0 was already set, no lost packets
@@ -221,62 +214,6 @@ func TestBitsLostCounter(t *testing.T) {
    assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
    z := NewBits(10)
    for n := 0; n < b.N; n++ {

View File

@@ -1,8 +1,10 @@
package cert

import (
+   "encoding/hex"
    "encoding/pem"
    "fmt"
+   "time"

    "golang.org/x/crypto/ed25519"
)
@@ -138,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
    }
}
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
issuerBytes, err := func() ([]byte, error) {
issuer := c.Issuer()
if issuer == "" {
return nil, nil
}
decoded, err := hex.DecodeString(issuer)
if err != nil {
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
return decoded, nil
}()
if err != nil {
return nil, rest, err
}
pubKey := c.PublicKey()
if pubKey != nil {
pubKey = append([]byte(nil), pubKey...)
}
sig := c.Signature()
if sig != nil {
sig = append([]byte(nil), sig...)
}
return &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: pubKey,
IsCA: c.IsCA(),
Issuer: issuerBytes,
Curve: c.Curve(),
},
Signature: sig,
cert: c,
}, rest, nil
}
// IssuerString returns the issuer in hex format for compatibility
func (n *NebulaCertificate) IssuerString() string {
if n.Details.Issuer == nil {
return ""
}
return hex.EncodeToString(n.Details.Issuer)
}
// Certificate returns the underlying certificate (read-only)
func (n *NebulaCertificate) Certificate() Certificate {
return n.cert
}
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
// consumed data or an error on failure
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
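The compat layer added above restores the pre-refactor cert API on top of the new one. A short usage sketch, assuming a PEM-encoded certificate at a hypothetical host.crt path and using only the functions and fields shown in the diff:

package main

import (
    "fmt"
    "log"
    "os"

    "github.com/slackhq/nebula/cert"
)

func main() {
    raw, err := os.ReadFile("host.crt") // hypothetical input file
    if err != nil {
        log.Fatal(err)
    }
    // Old-style entry point; internally it wraps UnmarshalCertificateFromPEM.
    nc, rest, err := cert.UnmarshalNebulaCertificateFromPEM(raw)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(nc.Details.Name, nc.Details.NotAfter, nc.IssuerString())
    fmt.Println("unparsed bytes remaining:", len(rest))

    // The old key-marshaling helpers simply delegate to the new PEM helpers.
    pubPEM := cert.MarshalX25519PublicKey(nc.Details.PublicKey)
    fmt.Printf("%s", pubPEM)
}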

View File

@@ -173,26 +173,23 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
    var passphrase []byte
    if !isP11 && *cf.encryption {
-       passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
-       if len(passphrase) == 0 {
-           for i := 0; i < 5; i++ {
-               out.Write([]byte("Enter passphrase: "))
-               passphrase, err = pr.ReadPassword()
-               if err == ErrNoTerminal {
-                   return fmt.Errorf("out-key must be encrypted interactively")
-               } else if err != nil {
-                   return fmt.Errorf("error reading passphrase: %s", err)
-               }
-               if len(passphrase) > 0 {
-                   break
-               }
-           }
-           if len(passphrase) == 0 {
-               return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
-           }
-       }
+       for i := 0; i < 5; i++ {
+           out.Write([]byte("Enter passphrase: "))
+           passphrase, err = pr.ReadPassword()
+           if err == ErrNoTerminal {
+               return fmt.Errorf("out-key must be encrypted interactively")
+           } else if err != nil {
+               return fmt.Errorf("error reading passphrase: %s", err)
+           }
+           if len(passphrase) > 0 {
+               break
+           }
+       }
+       if len(passphrase) == 0 {
+           return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
+       }
    }

View File

@@ -171,17 +171,6 @@ func Test_ca(t *testing.T) {
    assert.Equal(t, pwPromptOb, ob.String())
    assert.Empty(t, eb.String())

-   // test encrypted key with passphrase environment variable
-   os.Remove(keyF.Name())
-   os.Remove(crtF.Name())
-   ob.Reset()
-   eb.Reset()
-   args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-   os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
-   require.NoError(t, ca(args, ob, eb, testpw))
-   assert.Empty(t, eb.String())
-   os.Setenv("NEBULA_CA_PASSPHRASE", "")

    // read encrypted key file and verify default params
    rb, _ = os.ReadFile(keyF.Name())
    k, _ := pem.Decode(rb)

View File

@@ -5,28 +5,10 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime/debug"
"strings"
) )
// A version string that can be set with
//
// -ldflags "-X main.Build=SOMEVERSION"
//
// at compile-time.
var Build string var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
type helpError struct { type helpError struct {
s string s string
} }

View File

@@ -43,7 +43,7 @@ type signFlags struct {
func newSignFlags() *signFlags {
    sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
    sf.set.Usage = func() {}
-   sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
+   sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
    sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
    sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
    sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -116,28 +116,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
    // naively attempt to decode the private key as though it is not encrypted
    caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
    if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
-       var passphrase []byte
-       passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
-       if len(passphrase) == 0 {
-           // ask for a passphrase until we get one
-           for i := 0; i < 5; i++ {
-               out.Write([]byte("Enter passphrase: "))
-               passphrase, err = pr.ReadPassword()
-               if errors.Is(err, ErrNoTerminal) {
-                   return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
-               } else if err != nil {
-                   return fmt.Errorf("error reading password: %s", err)
-               }
-               if len(passphrase) > 0 {
-                   break
-               }
-           }
-           if len(passphrase) == 0 {
-               return fmt.Errorf("cannot open encrypted ca-key without passphrase")
-           }
-       }
+       // ask for a passphrase until we get one
+       var passphrase []byte
+       for i := 0; i < 5; i++ {
+           out.Write([]byte("Enter passphrase: "))
+           passphrase, err = pr.ReadPassword()
+           if errors.Is(err, ErrNoTerminal) {
+               return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
+           } else if err != nil {
+               return fmt.Errorf("error reading password: %s", err)
+           }
+           if len(passphrase) > 0 {
+               break
+           }
+       }
+       if len(passphrase) == 0 {
+           return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+       }
        curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
        if err != nil {
            return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
@@ -167,10 +165,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("ca certificate is expired") return fmt.Errorf("ca certificate is expired")
} }
if version == 0 {
version = caCert.Version()
}
// if no duration is given, expire one second before the root expires // if no duration is given, expire one second before the root expires
if *sf.duration <= 0 { if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -283,19 +277,21 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
    notBefore := time.Now()
    notAfter := notBefore.Add(*sf.duration)

-   switch version {
-   case cert.Version1:
-       // Make sure we have only one ipv4 address
+   if version == 0 || version == cert.Version1 {
+       // Make sure we at least have an ip
        if len(v4Networks) != 1 {
            return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
        }

-       if len(v6Networks) > 0 {
-           return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
-       }
-       if len(v6UnsafeNetworks) > 0 {
-           return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
+       if version == cert.Version1 {
+           // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
+           if len(v6Networks) > 0 {
+               return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
+           }
+           if len(v6UnsafeNetworks) > 0 {
+               return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+           }
        }

        t := &cert.TBSCertificate{
@@ -325,8 +321,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
        }
        crts = append(crts, nc)
+   }

-   case cert.Version2:
+   if version == 0 || version == cert.Version2 {
        t := &cert.TBSCertificate{
            Version: cert.Version2,
            Name:    *sf.name,
@@ -354,9 +351,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
        }
        crts = append(crts, nc)
-
-   default:
-       // this should be unreachable
-       return fmt.Errorf("invalid version: %d", version)
    }

    if !isP11 && *sf.inPubPath == "" {

View File

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
" -unsafe-networks string\n"+ " -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+ " -version uint\n"+
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n", " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(), ob.String(),
) )
} }
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
    ob.Reset()
    eb.Reset()
    args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-   assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
+   assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
    assert.Empty(t, ob.String())
    assert.Empty(t, eb.String())
@@ -379,15 +379,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the wrong password // test with the wrong password
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
@@ -398,17 +389,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test with the wrong password in environment
ob.Reset()
eb.Reset()
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the user not entering a password // test with the user not entering a password
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()

View File

@@ -4,8 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@@ -20,17 +18,6 @@ import (
// at compile-time.
var Build string

-func init() {
-   if Build == "" {
-       info, ok := debug.ReadBuildInfo()
-       if !ok {
-           return
-       }
-       Build = strings.TrimPrefix(info.Main.Version, "v")
-   }
-}

func main() {
    serviceFlag := flag.String("service", "", "Control the system service.")
    configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")

View File

@@ -4,8 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@@ -20,17 +18,6 @@ import (
// at compile-time.
var Build string

-func init() {
-   if Build == "" {
-       info, ok := debug.ReadBuildInfo()
-       if !ok {
-           return
-       }
-       Build = strings.TrimPrefix(info.Main.Version, "v")
-   }
-}

func main() {
    configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
    configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")

View File

@@ -17,7 +17,7 @@ import (
"dario.cat/mergo" "dario.cat/mergo"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
type C struct { type C struct {

View File

@@ -10,7 +10,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {

View File

@@ -50,6 +50,11 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
    }
    static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}

+   b := NewBits(ReplayWindow)
+   // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
+   b.Update(l, 0)

    hs, err := noise.NewHandshakeState(noise.Config{
        CipherSuite: ncs,
        Random:      rand.Reader,
@@ -69,7 +74,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
    ci := &ConnectionState{
        H:         hs,
        initiator: initiator,
-       window:    NewBits(ReplayWindow),
+       window:    b,
        myCert:    crt,
    }

    // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
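The three added lines exist because, per the comment above, message counters start at 2 (1 is consumed during the handshake) and counter 0 is never sent, so 0 is marked up front to keep it from ever being reported by the lost-packet counter. A compressed view of the idea (a fragment, assuming the nebula package context and a logger l):

b := NewBits(ReplayWindow)
b.Update(l, 0) // counter 0 is never transmitted, so consume it explicitly
// hand b to the ConnectionState as its replay window; the first data packet
// arrives with counter 2, and only genuinely skipped counters register as loss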

View File

@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
    return c.f.hostMap
}

-func (c *Control) GetF() *Interface {
-   return c.f
-}

func (c *Control) GetCertState() *CertState {
    return c.f.pki.getCertState()
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
@@ -97,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
    theirControl.Stop()
}
func TestGoodHandshakeNoOverlap(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
empty := []byte{}
t.Log("do something to cause a handshake")
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
myControl.WaitForType(header.Test, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestWrongResponderHandshake(t *testing.T) {
    ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -499,35 +464,6 @@ func TestRelays(t *testing.T) {
    r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestRelaysDontCareAboutIps(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestReestablishRelays(t *testing.T) {
    ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
    myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -1291,109 +1227,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
    myControl.Stop()
    theirControl.Stop()
}
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[0].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
myCfg := m{
"tun": m{
"unsafe_routes": []m{route},
},
}
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
spookyDest := netip.MustParseAddr("192.168.6.4")
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -22,14 +22,15 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "gopkg.in/yaml.v3"
"go.yaml.in/yaml/v3"
) )
type m = map[string]any type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") { for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
        budpIp[3] = 239
        udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
    }
-   return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
-}
-
-func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
-   return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
-}
-
-func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
-   l := NewTestLogger()
-   var vpnNetworks []netip.Prefix
-   for _, sn := range strings.Split(sVpnNetworks, ",") {
-       vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
-       if err != nil {
-           panic(err)
-       }
-       vpnNetworks = append(vpnNetworks, vpnIpNet)
-   }
-   if len(vpnNetworks) == 0 {
-       panic("no vpn networks")
-   }
-   firewallInbound := []m{{
-       "proto": "any",
-       "port": "any",
-       "host": "any",
-   }}
-   var unsafeNetworks []netip.Prefix
-   if sUnsafeNetworks != "" {
-       firewallInbound = []m{{
-           "proto": "any",
-           "port": "any",
-           "host": "any",
-           "local_cidr": "0.0.0.0/0",
-       }}
-       for _, sn := range strings.Split(sUnsafeNetworks, ",") {
-           x, err := netip.ParsePrefix(strings.TrimSpace(sn))
-           if err != nil {
-               panic(err)
-           }
-           unsafeNetworks = append(unsafeNetworks, x)
-       }
-   }
-   _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
+   _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})

    caB, err := caCrt.MarshalPEM()
    if err != nil {
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": "any", "port": "any",
"host": "any", "host": "any",
}}, }},
"inbound": firewallInbound, "inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
}, },
//"handshakes": m{ //"handshakes": m{
// "try_interval": "1s", // "try_interval": "1s",
@@ -308,10 +266,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
    // Get both host infos
    //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
    hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
-   require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
+   assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")

    hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
-   require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
+   assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")

    // Check that both vpn and real addr are correct
    assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")

View File

@@ -318,50 +318,3 @@ func TestCertMismatchCorrection(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach "me" how to get to the relay and that "them" can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason about who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}

View File

@@ -383,9 +383,8 @@ firewall:
# host: `any` or a literal hostname, ie `test-host` # host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended use case. # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended use case.
# ca_name: An issuing CA name # ca_name: An issuing CA name
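The `cidr` and `local_cidr` fields above take concrete prefixes such as `0.0.0.0/0` or `::/0`. A minimal sketch, assuming the map-based config style of the firewall tests later in this diff, of a rule combining the two; the port, the subnets, and the test name are illustrative assumptions, while config.NewC, test.NewLogger, mockFirewall, and AddFirewallRulesFromConfig all appear elsewhere in this change set.

package nebula

import (
	"testing"

	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/test"
	"github.com/stretchr/testify/require"
)

// Hypothetical test, not part of this change: allow inbound tcp/443 from any IPv4
// peer, but only toward one routed (unsafe_routes style) destination subnet.
func TestFirewallRuleCidrAndLocalCidrSketch(t *testing.T) {
	l := test.NewLogger()
	conf := config.NewC(l)
	mf := &mockFirewall{} // reuses the mock already defined for the firewall config tests
	conf.Settings["firewall"] = map[string]any{
		"inbound": []any{map[string]any{
			"port":       "443",
			"proto":      "tcp",
			"cidr":       "0.0.0.0/0",        // remote filter: any IPv4 source
			"local_cidr": "192.168.100.0/24", // destination filter behind this node
		}},
	}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
}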

View File

@@ -8,7 +8,6 @@ import (
"hash/fnv" "hash/fnv"
"net/netip" "net/netip"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -23,7 +22,7 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
} }
type conn struct { type conn struct {
@@ -248,11 +247,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp.IsValid() {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -260,7 +270,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@@ -287,7 +297,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto) return fmt.Errorf("unknown protocol %v", proto)
} }
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
} }
// GetRuleHash returns a hash representation of all inbound and outbound rules // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -327,6 +337,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
} }
for i, t := range rs { for i, t := range rs {
var groups []string
r, err := convertRule(l, t, table, i) r, err := convertRule(l, t, table, i)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err) return fmt.Errorf("%s rule #%v; %s", table, i, err)
@@ -336,10 +347,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
} }
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
} }
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string var sPort, errPort string
if r.Code != "" { if r.Code != "" {
errPort = "code" errPort = "code"
@@ -368,25 +392,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
} }
if r.Cidr != "" && r.Cidr != "any" { var cidr netip.Prefix
_, err = netip.ParsePrefix(r.Cidr) if r.Cidr != "" {
cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
} }
} }
if r.LocalCidr != "" && r.LocalCidr != "any" { var localCidr netip.Prefix
_, err = netip.ParsePrefix(r.LocalCidr) if r.LocalCidr != "" {
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
} }
} }
if warning := r.sanity(); warning != nil { err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
l.Warnf("%s rule #%v; %s", table, i, warning)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err) return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
} }
@@ -395,45 +417,30 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil return nil
} }
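The flattening above makes `group` and `groups` mutually exclusive at config-load time. A hedged sketch of that failure mode, reusing the same test scaffolding as the firewall config tests below; the test name, group values, and the exact error text quoted in the comment are assumptions, while the mutual-exclusion check itself is the code shown above.

package nebula

import (
	"testing"

	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/test"
	"github.com/stretchr/testify/require"
)

// Hypothetical test, not part of this change: a rule that sets both `group` and
// `groups` should be rejected before any rule reaches the firewall.
func TestFirewallRuleGroupAndGroupsConflictSketch(t *testing.T) {
	l := test.NewLogger()
	conf := config.NewC(l)
	mf := &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{
		"inbound": []any{map[string]any{
			"port":   "any",
			"proto":  "any",
			"group":  "ops",
			"groups": []any{"ops", "db"},
		}},
	}
	err := AddFirewallRulesFromConfig(l, true, conf, mf)
	require.Error(t, err) // roughly: "inbound rule #0; only one of group or groups should be defined, both provided"
}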
var ErrUnknownNetworkType = errors.New("unknown network type") var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrPeerRejected = errors.New("remote address is not within a network that we handle") var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table") var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It // Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped. // returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) { if f.inConns(fp, h, caPool, localCache) {
return nil return nil
} }
// Make sure remote address matches nebula certificate, and determine how to treat it // Make sure remote address matches nebula certificate
if h.networks == nil { if h.networks != nil {
// Simple case: Certificate has one address and no unsafe networks if !h.networks.Contains(fp.RemoteAddr) {
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} else { } else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr) // Simple case: Certificate has one address and no unsafe networks
if !ok { if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
@@ -483,11 +490,9 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
} }
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool {
if localCache != nil { if localCache != nil && localCache.Has(fp) {
if _, ok := localCache[fp]; ok { return true
return true
}
} }
conntrack := f.Conntrack conntrack := f.Conntrack
conntrack.Lock() conntrack.Lock()
@@ -552,7 +557,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
conntrack.Unlock() conntrack.Unlock()
if localCache != nil { if localCache != nil {
localCache[fp] = struct{}{} localCache.Add(fp)
} }
return true return true
@@ -633,7 +638,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
return false return false
} }
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port was lower than end port") return fmt.Errorf("start port was lower than end port")
} }
@@ -646,7 +651,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
} }
} }
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil { if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err return err
} }
} }
@@ -677,7 +682,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return fp[firewall.PortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error { func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule { fr := func() *FirewallRule {
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR), Hosts: make(map[string]*firewallLocalCIDR),
@@ -691,14 +696,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
fc.Any = fr() fc.Any = fr()
} }
return fc.Any.addRule(f, groups, host, cidr, localCidr) return fc.Any.addRule(f, groups, host, ip, localIp)
} }
if caSha != "" { if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok { if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr() fc.CAShas[caSha] = fr()
} }
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr) err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -708,7 +713,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
if _, ok := fc.CANames[caName]; !ok { if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr() fc.CANames[caName] = fr()
} }
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr) err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -740,24 +745,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
return fc.CANames[s.Certificate.Name()].match(p, c) return fc.CANames[s.Certificate.Name()].match(p, c)
} }
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error { func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR { flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{ return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite), LocalCIDR: new(bart.Lite),
} }
} }
if fr.isAny(groups, host, cidr) { if fr.isAny(groups, host, ip) {
if fr.Any == nil { if fr.Any == nil {
fr.Any = flc() fr.Any = flc()
} }
return fr.Any.addRule(f, localCidr) return fr.Any.addRule(f, localCIDR)
} }
if len(groups) > 0 { if len(groups) > 0 {
nlc := flc() nlc := flc()
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
@@ -773,34 +778,30 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localC
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.Hosts[host] = nlc fr.Hosts[host] = nlc
} }
if cidr != "" { if ip.IsValid() {
c, err := netip.ParsePrefix(cidr) nlc, _ := fr.CIDR.Get(ip)
if err != nil {
return err
}
nlc, _ := fr.CIDR.Get(c)
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err = nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.CIDR.Insert(c, nlc) fr.CIDR.Insert(ip, nlc)
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && cidr == "" { if len(groups) == 0 && host == "" && !ip.IsValid() {
return true return true
} }
@@ -814,7 +815,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
return true return true
} }
if cidr == "any" { if ip.IsValid() && ip.Bits() == 0 {
return true return true
} }
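With `cidr` now carried as a netip.Prefix instead of a string, the remote wildcard is no longer the literal `any`: an unset prefix or a 0-bit prefix plays that role (for `local_cidr`, a /0 prefix is the wildcard, while leaving it unset falls back to the behaviour in firewallLocalCIDR.addRule above). A small sketch of the remote-side check; the helper name is made up and the logic simply mirrors isAny.

package nebula

import "net/netip"

// isWildcardPrefix is illustration only, not part of this change: an unset prefix
// (the zero netip.Prefix) or a /0 prefix behaves like the old "any" value for cidr.
func isWildcardPrefix(p netip.Prefix) bool {
	return !p.IsValid() || p.Bits() == 0
}

// isWildcardPrefix(netip.Prefix{})                      // true: field omitted
// isWildcardPrefix(netip.MustParsePrefix("0.0.0.0/0"))  // true: any IPv4
// isWildcardPrefix(netip.MustParsePrefix("::/0"))       // true: any IPv6
// isWildcardPrefix(netip.MustParsePrefix("10.0.0.0/8")) // false: a real filter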
@@ -866,13 +867,8 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
return false return false
} }
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if localCidr == "any" { if !localIp.IsValid() {
flc.Any = true
return nil
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true flc.Any = true
return nil return nil
@@ -883,13 +879,12 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
} }
return nil return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
} }
c, err := netip.ParsePrefix(localCidr) flc.LocalCIDR.Insert(localIp)
if err != nil {
return err
}
flc.LocalCIDR.Insert(c)
return nil return nil
} }
@@ -910,6 +905,7 @@ type rule struct {
Code string Code string
Proto string Proto string
Host string Host string
Group string
Groups []string Groups []string
Cidr string Cidr string
LocalCidr string LocalCidr string
@@ -951,8 +947,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
m["group"] = v[0] m["group"] = v[0]
} }
r.Group = toString("group", m)
singleGroup := toString("group", m)
if rg, ok := m["groups"]; ok { if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() { switch reflect.TypeOf(rg).Kind() {
@@ -969,60 +964,9 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
} }
} }
//flatten group vs groups
if singleGroup != "" {
// Check if we have both groups and group provided in the rule config
if len(r.Groups) > 0 {
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
}
r.Groups = []string{singleGroup}
}
return r, nil return r, nil
} }
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
func (r *rule) sanity() error {
//port, proto, local_cidr are AND, no need to check here
//ca_sha and ca_name don't have a wildcard value, no need to check here
groupsEmpty := len(r.Groups) == 0
hostEmpty := r.Host == ""
cidrEmpty := r.Cidr == ""
if groupsEmpty && hostEmpty && cidrEmpty {
return nil //no content!
}
groupsHasAny := slices.Contains(r.Groups, "any")
if groupsHasAny && len(r.Groups) > 1 {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
}
if r.Host == "any" {
if !groupsEmpty {
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
}
if !cidrEmpty {
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
}
}
if groupsHasAny {
if !hostEmpty && r.Host != "any" {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
}
if !cidrEmpty {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
}
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) { func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" { if s == "any" {
startPort = firewall.PortAny startPort = firewall.PortAny

View File

@@ -1,6 +1,7 @@
package firewall package firewall
import ( import (
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -9,13 +10,58 @@ import (
// ConntrackCache is used as a local routine cache to know if a given flow // ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table. // has been seen in the conntrack table.
type ConntrackCache map[Packet]struct{} type ConntrackCache struct {
mu sync.Mutex
entries map[Packet]struct{}
}
func newConntrackCache() *ConntrackCache {
return &ConntrackCache{entries: make(map[Packet]struct{})}
}
func (c *ConntrackCache) Has(p Packet) bool {
if c == nil {
return false
}
c.mu.Lock()
_, ok := c.entries[p]
c.mu.Unlock()
return ok
}
func (c *ConntrackCache) Add(p Packet) {
if c == nil {
return
}
c.mu.Lock()
c.entries[p] = struct{}{}
c.mu.Unlock()
}
func (c *ConntrackCache) Len() int {
if c == nil {
return 0
}
c.mu.Lock()
l := len(c.entries)
c.mu.Unlock()
return l
}
func (c *ConntrackCache) Reset(capHint int) {
if c == nil {
return
}
c.mu.Lock()
c.entries = make(map[Packet]struct{}, capHint)
c.mu.Unlock()
}
type ConntrackCacheTicker struct { type ConntrackCacheTicker struct {
cacheV uint64 cacheV uint64
cacheTick atomic.Uint64 cacheTick atomic.Uint64
cache ConntrackCache cache *ConntrackCache
} }
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
@@ -23,9 +69,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
return nil return nil
} }
c := &ConntrackCacheTicker{ c := &ConntrackCacheTicker{cache: newConntrackCache()}
cache: ConntrackCache{},
}
go c.tick(d) go c.tick(d)
@@ -41,17 +85,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
// Get checks if the cache ticker has moved to the next version before returning // Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map. // the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
if c == nil { if c == nil {
return nil return nil
} }
if tick := c.cacheTick.Load(); tick != c.cacheV { if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick c.cacheV = tick
if ll := len(c.cache); ll > 0 { if ll := c.cache.Len(); ll > 0 {
if l.Level == logrus.DebugLevel { if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache") l.WithField("len", ll).Debug("resetting conntrack cache")
} }
c.cache = make(ConntrackCache, ll) c.cache.Reset(ll)
} }
} }
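The ticker hands each routine a *ConntrackCache, and the cache is now guarded by a mutex so a reset cannot race a concurrent Has or Add. A hedged sketch of the intended call pattern from a packet-processing loop; processLoop, the packets channel, and the surrounding wiring are assumptions, while NewConntrackCacheTicker, Get, Has/Add, and Firewall.Drop are the APIs shown in this change set.

package nebula

import (
	"time"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/firewall"
)

// processLoop is a hypothetical caller, not part of this change: one ticker per
// routine, a Get() per packet, and the returned *ConntrackCache handed straight to
// Firewall.Drop, which consults Has/Add internally via inConns.
func processLoop(l *logrus.Logger, fw *Firewall, h *HostInfo, caPool *cert.CAPool, packets <-chan firewall.Packet) {
	cc := firewall.NewConntrackCacheTicker(2 * time.Second)
	for fp := range packets {
		cache := cc.Get(l) // reset automatically once the tick advances
		if err := fw.Drop(fp, true, h, caPool, cache); err != nil {
			continue // dropped: no matching rule, spoofed address, etc.
		}
		// forward the packet ...
	}
}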

View File

@@ -8,8 +8,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -73,114 +71,85 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128") ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0") anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -205,10 +174,10 @@ func TestFirewall_Drop(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -227,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -257,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -284,10 +250,10 @@ func TestFirewall_DropV6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -306,28 +272,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -338,12 +304,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
pfix := netip.MustParsePrefix("172.1.1.1/32") pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128") pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@@ -487,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -514,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.CachedCertificate{ c1 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -529,10 +493,10 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1, peerCert: &c1,
}, },
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@@ -546,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -579,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.CachedCertificate{ c2 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -594,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h2.buildNetworks(myVpnNetworksTable, c2.Certificate) h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.CachedCertificate{ c3 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -609,11 +571,11 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h3.buildNetworks(myVpnNetworksTable, c3.Certificate) h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@@ -627,7 +589,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match // Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
@@ -635,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
@@ -660,12 +620,12 @@ func TestFirewall_Drop3V6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match // Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool() cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -673,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -701,10 +659,10 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -717,7 +675,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -726,7 +684,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -738,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
c := cert.CachedCertificate{ c := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -761,11 +717,11 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -959,28 +915,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8") cidr := netip.MustParsePrefix("10.0.0.0/8")
@@ -988,14 +944,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6 // Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8") cidr6 := netip.MustParsePrefix("fd00::/8")
@@ -1003,75 +959,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6 // Test adding rule with local_cidr ipv6
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(l)
@@ -1094,7 +1024,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
	// Ensure group array of > 1 is errored	// Ensure group array of > 1 is errored
ob.Reset() ob.Reset()
@@ -1114,228 +1044,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"host": "bob"},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range noWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
}
yesWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
c["host"] = "any"
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
//reset the list
yesWarningPlease = []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
r.Groups = append(r.Groups, "any")
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
}
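One side of this diff carries a `TestFirewall_convertRuleSanity` test driving a `sanity()` check whose implementation is not included here; judging from the cases it builds, the intent appears to be flagging rules that pair an `any` selector with narrower selectors. A rough, purely illustrative sketch of that idea (the name, signature, and selector map are hypothetical):

```go
package main

import (
	"errors"
	"fmt"
)

// hypotheticalSanity sketches the idea the test above exercises: a rule that
// already matches everything via an "any" selector should not also carry more
// specific selectors, because they can never narrow the match.
func hypotheticalSanity(selectors map[string]string) error {
	hasAny := false
	others := 0
	for _, v := range selectors {
		if v == "any" {
			hasAny = true
		} else if v != "" {
			others++
		}
	}
	if hasAny && others > 0 {
		return errors.New("rule mixes an 'any' selector with more specific selectors")
	}
	return nil
}

func main() {
	fmt.Println(hypotheticalSanity(map[string]string{"host": "any", "cidr": "1.1.1.1/1"})) // flagged
	fmt.Println(hypotheticalSanity(map[string]string{"host": "bob", "cidr": "1.1.1.1/1"})) // ok
}
```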
type testcase struct {
h *HostInfo
p firewall.Packet
c cert.Certificate
err error
}
func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c1,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
}
for i := range theirPrefixes {
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
return testcase{
h: &h,
p: p,
c: &c1,
err: err,
}
}
type testsetup struct {
c dummyCert
myVpnNetworksTable *bart.Lite
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
c: c,
fw: fw,
myVpnNetworksTable: myVpnNetworksTable,
}
}
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
tc.Test(t, setup.fw)
})
t.Run("allow inbound local matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
tc.Test(t, setup.fw)
})
t.Run("block inbound remote mismatched", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("Block a vpn peer packet", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
tc.Test(t, setup.fw)
})
twoPrefixes := []netip.Prefix{
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
}
t.Run("allow inbound one matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, twoPrefixes...)
tc.Test(t, setup.fw)
})
t.Run("block inbound multimismatch", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
t.Parallel()
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
tc := buildTestCase(setup2, nil, twoPrefixes...)
tc.p.RemoteAddr = twoPrefixes[1].Addr()
tc.Test(t, setup2.fw)
})
t.Run("allow inbound unsafe route", func(t *testing.T) {
t.Parallel()
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
unsafeNetworks: []netip.Prefix{unsafePrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
unsafeSetup := newSetupFromCert(t, l, c)
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
} }
type addRuleCall struct { type addRuleCall struct {
@@ -1345,8 +1054,8 @@ type addRuleCall struct {
endPort int32 endPort int32
groups []string groups []string
host string host string
ip string ip netip.Prefix
localIp string localIp netip.Prefix
caName string caName string
caSha string caSha string
} }
@@ -1356,7 +1065,7 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,
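The mock firewall's `AddRule` signature changes `ip` and `localIp` from strings to `netip.Prefix` values, so the assertions above compare parsed prefixes rather than raw config strings. Below is a minimal sketch of how a caller might turn the optional `cidr`/`local_cidr` config strings into that type, assuming an empty value maps to the zero `netip.Prefix`; how the real `AddFirewallRulesFromConfig` handles the `"any"` keyword is not shown in this hunk.

```go
package main

import (
	"fmt"
	"net/netip"
)

// parseRuleCIDR converts an optional cidr/local_cidr string from the firewall
// config into a netip.Prefix. An empty value yields the zero Prefix, which the
// new AddRule signature is assumed to treat as "no restriction".
func parseRuleCIDR(s string) (netip.Prefix, error) {
	if s == "" {
		return netip.Prefix{}, nil
	}
	p, err := netip.ParsePrefix(s)
	if err != nil {
		return netip.Prefix{}, fmt.Errorf("cidr did not parse: %w", err)
	}
	return p, nil
}

func main() {
	for _, s := range []string{"", "10.0.0.0/8", "fd00::/8", "junk/junk"} {
		p, err := parseRuleCIDR(s)
		fmt.Printf("%q -> %v (err: %v)\n", s, p, err)
	}
}
```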

go.mod

@@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0 github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.26.0 github.com/gaissmai/bart v0.25.0
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4 github.com/kardianos/service v1.2.4
@@ -22,17 +22,16 @@ require (
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.45.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.47.0 golang.org/x/net v0.45.0
golang.org/x/sync v0.18.0 golang.org/x/sync v0.17.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.37.0
golang.org/x/term v0.37.0 golang.org/x/term v0.36.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.10 google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
) )

go.sum

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -155,15 +155,13 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +189,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -209,11 +207,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -244,8 +242,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=


@@ -2,6 +2,7 @@ package nebula
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
@@ -191,17 +192,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version() certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
vpnNetworks := remoteCert.Certificate.Networks()
anyVpnAddrsInCommon := false for _, network := range remoteCert.Certificate.Networks() {
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddr := network.Addr()
for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(vpnAddr) {
if f.myVpnAddrsTable.Contains(network.Addr()) { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -209,10 +210,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return return
} }
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { // vpnAddrs outside our vpn networks are of no use to us, filter them out
anyVpnAddrsInCommon = true if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
} }
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
} }
if addr.IsValid() { if addr.IsValid() {
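One side of this hunk filters the peer certificate's networks down to addresses that fall inside this node's own VPN networks and refuses the handshake when nothing usable remains. A dependency-free sketch of that filtering step, using a plain prefix slice in place of the routing table the real code consults (that substitution is an assumption made for illustration):

```go
package main

import (
	"errors"
	"fmt"
	"net/netip"
)

// filterUsableAddrs keeps only the peer addresses that fall inside one of our
// own VPN networks; addresses outside them are of no use to this node.
func filterUsableAddrs(myNetworks, peerNetworks []netip.Prefix) ([]netip.Addr, []netip.Prefix, error) {
	var addrs []netip.Addr
	var kept []netip.Prefix
	for _, pn := range peerNetworks {
		a := pn.Addr()
		usable := false
		for _, mine := range myNetworks {
			if mine.Contains(a) {
				usable = true
				break
			}
		}
		if !usable {
			continue
		}
		addrs = append(addrs, a)
		kept = append(kept, pn)
	}
	if len(addrs) == 0 {
		return nil, nil, errors.New("no usable vpn addresses from host")
	}
	return addrs, kept, nil
}

func main() {
	mine := []netip.Prefix{netip.MustParsePrefix("10.1.0.0/16")}
	theirs := []netip.Prefix{
		netip.MustParsePrefix("10.1.2.3/32"),
		netip.MustParsePrefix("192.168.5.9/32"),
	}
	addrs, kept, err := filterUsableAddrs(mine, theirs)
	fmt.Println(addrs, kept, err)
}
```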
@@ -249,30 +264,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}, },
} }
msgRxL := f.l.WithFields(m{ f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
"vpnAddrs": vpnAddrs, WithField("certName", certName).
"udpAddr": addr, WithField("certVersion", certVersion).
"certName": certName, WithField("fingerprint", fingerprint).
"certVersion": certVersion, WithField("issuer", issuer).
"fingerprint": fingerprint, WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
"issuer": issuer, WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
"initiatorIndex": hs.Details.InitiatorIndex, Info("Handshake message received")
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil { if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available") Error("Unable to handshake with host because no certificate handshake bytes is available")
return return
} }
@@ -330,7 +341,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil { if err != nil {
@@ -571,22 +582,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
correctHostResponded := false var vpnAddrs []netip.Addr
anyVpnAddrsInCommon := false var filteredNetworks []netip.Prefix
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for _, network := range vpnNetworks {
for i, network := range vpnNetworks { // vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddrs[i] = network.Addr() vpnAddr := network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { if !f.myVpnNetworksTable.Contains(vpnAddr) {
anyVpnAddrsInCommon = true continue
}
if hostinfo.vpnAddrs[0] == network.Addr() {
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
correctHostResponded = true
} }
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
} }
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr). WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
@@ -598,7 +618,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip // Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
@@ -625,7 +644,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -633,17 +652,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration). WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)) WithField("sentCachedPackets", len(hh.packetStore)).
if anyVpnAddrsInCommon { Info("Handshake message received")
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)


@@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to // Don't relay to myself
if relay == vpnIp { if relay == vpnIp {
continue continue
} }
// Don't relay to myself // Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) { if hm.f.myVpnAddrsTable.Contains(relay) {
continue continue
} }


@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r rs.relayForByIdx[idx] = r
} }
type NetworkType uint8
const (
NetworkTypeUnknown NetworkType = iota
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
NetworkTypeVPN
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
)
type HostInfo struct { type HostInfo struct {
remote netip.AddrPort remote netip.AddrPort
remotes *RemoteList remotes *RemoteList
@@ -237,8 +225,8 @@ type HostInfo struct {
// vpn networks but were removed because they are not usable // vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. // networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[NetworkType] networks *bart.Lite
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@@ -742,26 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false return false
} }
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up. func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) { if len(networks) == 1 && len(unsafeNetworks) == 0 {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) { return
return // Simple case, no BART needed
}
} }
i.networks = new(bart.Table[NetworkType]) i.networks = new(bart.Lite)
for _, network := range c.Networks() { for _, network := range networks {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
if myVpnNetworksTable.Contains(network.Addr()) { i.networks.Insert(nprefix)
i.networks.Insert(nprefix, NetworkTypeVPN)
} else {
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
}
} }
for _, network := range c.UnsafeNetworks() { for _, network := range unsafeNetworks {
i.networks.Insert(network, NetworkTypeUnsafe) i.networks.Insert(network)
} }
} }
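One side of this change tags each entry in the host's network set with a `NetworkType` (own-VPN address, peer-only address, or unsafe network) rather than storing an untyped set. The sketch below mimics that classification with a plain map keyed by prefix instead of the `bart` routing table the real code uses; the container choice is illustrative only.

```go
package main

import (
	"fmt"
	"net/netip"
)

type NetworkType uint8

const (
	NetworkTypeUnknown NetworkType = iota
	NetworkTypeVPN                 // overlaps one of our own vpn networks
	NetworkTypeVPNPeer             // a peer network that does not overlap ours
	NetworkTypeUnsafe              // taken from the certificate's unsafe networks
)

// classify assigns a NetworkType to every address and unsafe network in the
// peer certificate, relative to our own vpn networks.
func classify(myNetworks, certNetworks, unsafeNetworks []netip.Prefix) map[netip.Prefix]NetworkType {
	inMine := func(a netip.Addr) bool {
		for _, p := range myNetworks {
			if p.Contains(a) {
				return true
			}
		}
		return false
	}
	out := make(map[netip.Prefix]NetworkType)
	for _, n := range certNetworks {
		// store the single address, not the whole prefix, mirroring the host/32 entries above
		host := netip.PrefixFrom(n.Addr(), n.Addr().BitLen())
		if inMine(n.Addr()) {
			out[host] = NetworkTypeVPN
		} else {
			out[host] = NetworkTypeVPNPeer
		}
	}
	for _, n := range unsafeNetworks {
		out[n] = NetworkTypeUnsafe
	}
	return out
}

func main() {
	mine := []netip.Prefix{netip.MustParsePrefix("10.1.0.0/16")}
	fmt.Println(classify(mine,
		[]netip.Prefix{netip.MustParsePrefix("10.1.2.3/24"), netip.MustParsePrefix("172.16.0.1/12")},
		[]netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}))
}
```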

inside.go

@@ -2,16 +2,18 @@ package nebula
import ( import (
"net/netip" "net/netip"
"unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
@@ -120,10 +122,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
// it does not check if it is within our vpn networks!
func (f *Interface) Handshake(vpnAddr netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.handshakeManager.GetOrHandshake(vpnAddr, nil) f.getOrHandshakeNoRouting(vpnAddr, nil)
} }
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -139,6 +140,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -231,10 +233,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })
@@ -336,9 +337,21 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
if ci.eKey == nil { if ci.eKey == nil {
return return
} }
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() target := remote
if !target.IsValid() {
target = hostinfo.remote
}
useRelay := !target.IsValid()
fullOut := out fullOut := out
var pkt *overlay.Packet
if !useRelay && f.batches.Enabled() {
pkt = f.batches.newPacket()
if pkt != nil {
out = pkt.Payload()[:0]
}
}
if useRelay { if useRelay {
if len(out) < header.Len { if len(out) < header.Len {
// out always has a capacity of mtu, but not always a length greater than the header.Len. // out always has a capacity of mtu, but not always a length greater than the header.Len.
@@ -372,41 +385,85 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} }
var err error var err error
if len(p) > 0 && slicesOverlap(out, p) {
tmp := make([]byte, len(p))
copy(tmp, p)
p = tmp
}
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
if noiseutil.EncryptLockNeeded { if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock() ci.writeLock.Unlock()
} }
if err != nil { if err != nil {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", target).WithField("counter", c).
WithField("attemptedCounter", c). WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
return return
} }
if remote.IsValid() { if target.IsValid() {
err = f.writers[q].WriteTo(out, remote) if pkt != nil {
if err != nil { pkt.Len = len(out)
hostinfo.logger(f.l).WithError(err). if f.l.Level >= logrus.DebugLevel {
WithField("udpAddr", remote).Error("Failed to write outgoing packet") f.l.WithFields(logrus.Fields{
} "queue": q,
} else if hostinfo.remote.IsValid() { "dest": target,
err = f.writers[q].WriteTo(out, hostinfo.remote) "payload_len": pkt.Len,
if err != nil { "use_batches": true,
hostinfo.logger(f.l).WithError(err). "remote_index": hostinfo.remoteIndexId,
WithField("udpAddr", remote).Error("Failed to write outgoing packet") }).Debug("enqueueing packet to UDP batch queue")
}
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
} }
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) if f.tryQueuePacket(q, pkt, target) {
break return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"queue": q,
"dest": target,
}).Debug("failed to enqueue packet; falling back to immediate send")
}
f.writeImmediatePacket(q, pkt, target, hostinfo)
return
} }
if f.tryQueueDatagram(q, out, target) {
return
}
f.writeImmediate(q, out, target, hostinfo)
return
}
// fall back to relay path
if pkt != nil {
pkt.Release()
}
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
}
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
break
} }
} }
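With batching enabled, the send path above first tries to hand the encrypted packet to a per-queue batch writer and only writes to the socket inline when that queue is full. The real `tryQueuePacket`/`writeImmediate` helpers are not part of this hunk, so the following is only a sketch of that enqueue-or-fallback pattern using a buffered channel:

```go
package main

import (
	"fmt"
	"net/netip"
)

type datagram struct {
	payload []byte
	dest    netip.AddrPort
}

type sender struct {
	queue chan datagram
}

// trySend enqueues the datagram for a background flusher if there is room,
// and otherwise writes it out immediately so packets are never dropped just
// because the batch queue is momentarily full.
func (s *sender) trySend(d datagram, writeNow func(datagram)) {
	select {
	case s.queue <- d:
		// picked up by the batch flusher
	default:
		writeNow(d)
	}
}

func main() {
	s := &sender{queue: make(chan datagram, 1)}
	dest := netip.MustParseAddrPort("192.0.2.1:4242")
	write := func(d datagram) { fmt.Println("immediate write to", d.dest) }

	s.trySend(datagram{payload: []byte("a"), dest: dest}, write) // queued
	s.trySend(datagram{payload: []byte("b"), dest: dest}, write) // queue full -> immediate
}
```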
// slicesOverlap reports whether the two byte slices share any portion of memory.
// cipher.AEAD.Seal requires plaintext and dst to live in disjoint regions.
func slicesOverlap(a, b []byte) bool {
if len(a) == 0 || len(b) == 0 {
return false
}
aStart := uintptr(unsafe.Pointer(&a[0]))
aEnd := aStart + uintptr(len(a))
bStart := uintptr(unsafe.Pointer(&b[0]))
bEnd := bStart + uintptr(len(b))
return aStart < bEnd && bStart < aEnd
}
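A small, self-contained illustration of why this guard exists: two sub-slices of one backing array alias each other, which is exactly the situation the AEAD `Seal` behind `EncryptDanger` cannot tolerate between `dst` and the plaintext. The helper below mirrors `slicesOverlap` so the snippet runs on its own:

```go
package main

import (
	"fmt"
	"unsafe"
)

// overlaps mirrors the slicesOverlap helper above: it reports whether two byte
// slices share any of the same backing memory.
func overlaps(a, b []byte) bool {
	if len(a) == 0 || len(b) == 0 {
		return false
	}
	aStart := uintptr(unsafe.Pointer(&a[0]))
	aEnd := aStart + uintptr(len(a))
	bStart := uintptr(unsafe.Pointer(&b[0]))
	bEnd := bStart + uintptr(len(b))
	return aStart < bEnd && bStart < aEnd
}

func main() {
	buf := make([]byte, 16)
	out := buf[:8]
	payload := buf[4:12] // same array, shifted window

	fmt.Println(overlaps(out, payload))         // true: sealing in place would corrupt the plaintext
	fmt.Println(overlaps(out, make([]byte, 8))) // false: independent allocation is safe
}
```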


@@ -8,6 +8,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -21,7 +22,13 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
const mtu = 9001 const (
mtu = 9001
defaultGSOFlushInterval = 150 * time.Microsecond
defaultBatchQueueDepthFactor = 4
defaultGSOMaxSegments = 8
maxKernelGSOSegments = 64
)
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
@@ -36,6 +43,9 @@ type InterfaceConfig struct {
connectionManager *connectionManager connectionManager *connectionManager
DropLocalBroadcast bool DropLocalBroadcast bool
DropMulticast bool DropMulticast bool
EnableGSO bool
EnableGRO bool
GSOMaxSegments int
routines int routines int
MessageMetrics *MessageMetrics MessageMetrics *MessageMetrics
version string version string
@@ -47,6 +57,8 @@ type InterfaceConfig struct {
reQueryWait time.Duration reQueryWait time.Duration
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
BatchFlushInterval time.Duration
BatchQueueDepth int
l *logrus.Logger l *logrus.Logger
} }
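Further down, `NewInterface` clamps and defaults these new batching knobs before wiring up the batch pipelines. A compact restatement of that normalization as a standalone helper; the `batchConfig` type and method name are illustrative, while the constants and clamping rules match the ones added in this file.

```go
package main

import (
	"fmt"
	"time"
)

const (
	defaultGSOFlushInterval      = 150 * time.Microsecond
	defaultBatchQueueDepthFactor = 4
	defaultGSOMaxSegments        = 8
	maxKernelGSOSegments         = 64
)

type batchConfig struct {
	EnableGSO          bool
	GSOMaxSegments     int
	BatchQueueDepth    int
	BatchFlushInterval time.Duration
	routines           int
}

// normalize applies the same defaulting the interface constructor performs:
// clamp GSO segment counts to what the kernel accepts, size the queue from the
// number of routines, and only enable a flush timer when GSO batching is on.
func (c *batchConfig) normalize() {
	if c.GSOMaxSegments <= 0 {
		c.GSOMaxSegments = defaultGSOMaxSegments
	}
	if c.GSOMaxSegments > maxKernelGSOSegments {
		c.GSOMaxSegments = maxKernelGSOSegments
	}
	if c.BatchQueueDepth <= 0 {
		c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
	}
	if c.BatchFlushInterval < 0 {
		c.BatchFlushInterval = 0
	}
	if c.BatchFlushInterval == 0 && c.EnableGSO {
		c.BatchFlushInterval = defaultGSOFlushInterval
	}
}

func main() {
	c := batchConfig{EnableGSO: true, routines: 2}
	c.normalize()
	fmt.Printf("%+v\n", c)
}
```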
@@ -84,9 +96,20 @@ type Interface struct {
version string version string
conntrackCacheTimeout time.Duration conntrackCacheTimeout time.Duration
batchQueueDepth int
enableGSO bool
enableGRO bool
gsoMaxSegments int
batchUDPQueueGauge metrics.Gauge
batchUDPFlushCounter metrics.Counter
batchTunQueueGauge metrics.Gauge
batchTunFlushCounter metrics.Counter
batchFlushInterval atomic.Int64
sendSem chan struct{}
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
batches batchPipelines
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
@@ -161,6 +184,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no connection manager") return nil, errors.New("no connection manager")
} }
if c.GSOMaxSegments <= 0 {
c.GSOMaxSegments = defaultGSOMaxSegments
}
if c.GSOMaxSegments > maxKernelGSOSegments {
c.GSOMaxSegments = maxKernelGSOSegments
}
if c.BatchQueueDepth <= 0 {
c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
}
if c.BatchFlushInterval < 0 {
c.BatchFlushInterval = 0
}
if c.BatchFlushInterval == 0 && c.EnableGSO {
c.BatchFlushInterval = defaultGSOFlushInterval
}
cs := c.pki.getCertState() cs := c.pki.getCertState()
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
@@ -186,6 +225,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
relayManager: c.relayManager, relayManager: c.relayManager,
connectionManager: c.connectionManager, connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
batchQueueDepth: c.BatchQueueDepth,
enableGSO: c.EnableGSO,
enableGRO: c.EnableGRO,
gsoMaxSegments: c.GSOMaxSegments,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics, messageMetrics: c.MessageMetrics,
@@ -198,8 +241,25 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
} }
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
ifce.sendSem = make(chan struct{}, c.routines)
ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
if c.l.Level >= logrus.DebugLevel {
c.l.WithFields(logrus.Fields{
"enableGSO": c.EnableGSO,
"enableGRO": c.EnableGRO,
"gsoMaxSegments": c.GSOMaxSegments,
"batchQueueDepth": c.BatchQueueDepth,
"batchFlush": c.BatchFlushInterval,
"batching": ifce.batches.Enabled(),
}).Debug("initialized batch pipelines")
}
ifce.connectionManager.intf = ifce ifce.connectionManager.intf = ifce
@@ -222,13 +282,6 @@ func (f *Interface) activate() {
WithField("boringcrypto", boringEnabled()). WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active") Info("Nebula interface is active")
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
f.routines = 1
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
}
}
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
@@ -255,6 +308,18 @@ func (f *Interface) run() {
go f.listenOut(i) go f.listenOut(i)
} }
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
}
if f.batches.Enabled() {
for i := 0; i < f.routines; i++ {
go f.runInsideBatchWorker(i)
go f.runTunWriteQueue(i)
go f.runSendQueue(i)
}
}
// Launch n queues to read packets from tun dev // Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i) go f.listenIn(f.readers[i], i)
@@ -286,6 +351,17 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
if f.batches.Enabled() {
if br, ok := reader.(overlay.BatchReader); ok {
f.listenInBatchLocked(reader, br, i)
return
}
}
f.listenInLegacyLocked(reader, i)
}
func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu) packet := make([]byte, mtu)
out := make([]byte, mtu) out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
@@ -309,6 +385,581 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
} }
} }
func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
pool := f.batches.Pool()
if pool == nil {
f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
f.listenInLegacyLocked(raw, i)
return
}
for {
packets, err := reader.ReadIntoBatch(pool)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
if isVirtioHeadroomError(err) {
f.l.WithError(err).Warn("Batch reader fell back due to tun headroom issue")
f.listenInLegacyLocked(raw, i)
return
}
f.l.WithError(err).Error("Error while reading outbound packet batch")
os.Exit(2)
}
if len(packets) == 0 {
continue
}
for _, pkt := range packets {
if pkt == nil {
continue
}
if !f.batches.enqueueRx(i, pkt) {
pkt.Release()
}
}
}
}
func (f *Interface) runInsideBatchWorker(i int) {
queue := f.batches.rxQueue(i)
if queue == nil {
return
}
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for pkt := range queue {
if pkt == nil {
continue
}
f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
pkt.Release()
}
}
func (f *Interface) runSendQueue(i int) {
queue := f.batches.txQueue(i)
if queue == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
}
return
}
writer := f.writerForIndex(i)
if writer == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
}
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("queue", i).Debug("send queue worker started")
}
defer func() {
if f.l.Level >= logrus.WarnLevel {
f.l.WithField("queue", i).Warn("send queue worker exited")
}
}()
batchCap := f.batches.batchSizeHint()
if batchCap <= 0 {
batchCap = 1
}
gsoLimit := f.effectiveGSOMaxSegments()
if gsoLimit > batchCap {
batchCap = gsoLimit
}
pending := make([]queuedDatagram, 0, batchCap)
var (
flushTimer *time.Timer
flushC <-chan time.Time
)
dispatch := func(reason string, timerFired bool) {
if len(pending) == 0 {
return
}
batch := pending
f.flushAndReleaseBatch(i, writer, batch, reason)
for idx := range batch {
batch[idx] = queuedDatagram{}
}
pending = pending[:0]
if flushTimer != nil {
if !timerFired {
if !flushTimer.Stop() {
select {
case <-flushTimer.C:
default:
}
}
}
flushTimer = nil
flushC = nil
}
}
armTimer := func() {
delay := f.currentBatchFlushInterval()
if delay <= 0 {
dispatch("nogso", false)
return
}
if flushTimer == nil {
flushTimer = time.NewTimer(delay)
flushC = flushTimer.C
}
}
for {
select {
case d := <-queue:
if d.packet == nil {
continue
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"queue": i,
"payload_len": d.packet.Len,
"dest": d.addr,
}).Debug("send queue received packet")
}
pending = append(pending, d)
if gsoLimit > 0 && len(pending) >= gsoLimit {
dispatch("gso", false)
continue
}
if len(pending) >= cap(pending) {
dispatch("cap", false)
continue
}
armTimer()
f.observeUDPQueueLen(i)
case <-flushC:
dispatch("timer", true)
}
}
}
func (f *Interface) runTunWriteQueue(i int) {
queue := f.batches.tunQueue(i)
if queue == nil {
return
}
writer := f.batches.inside
if writer == nil {
return
}
requiredHeadroom := writer.BatchHeadroom()
batchCap := f.batches.batchSizeHint()
if batchCap <= 0 {
batchCap = 1
}
pending := make([]*overlay.Packet, 0, batchCap)
var (
flushTimer *time.Timer
flushC <-chan time.Time
)
flush := func(reason string, timerFired bool) {
if len(pending) == 0 {
return
}
valid := pending[:0]
for idx := range pending {
if !f.ensurePacketHeadroom(&pending[idx], requiredHeadroom, i, reason) {
pending[idx] = nil
continue
}
if pending[idx] != nil {
valid = append(valid, pending[idx])
}
}
if len(valid) > 0 {
if _, err := writer.WriteBatch(valid); err != nil {
f.l.WithError(err).
WithField("queue", i).
WithField("reason", reason).
Warn("Failed to write tun batch")
for _, pkt := range valid {
if pkt != nil {
f.writePacketToTun(i, pkt)
}
}
}
}
pending = pending[:0]
if flushTimer != nil {
if !timerFired {
if !flushTimer.Stop() {
select {
case <-flushTimer.C:
default:
}
}
}
flushTimer = nil
flushC = nil
}
}
armTimer := func() {
delay := f.currentBatchFlushInterval()
if delay <= 0 {
return
}
if flushTimer == nil {
flushTimer = time.NewTimer(delay)
flushC = flushTimer.C
}
}
for {
select {
case pkt := <-queue:
if pkt == nil {
continue
}
if f.ensurePacketHeadroom(&pkt, requiredHeadroom, i, "queue") {
pending = append(pending, pkt)
}
if len(pending) >= cap(pending) {
flush("cap", false)
continue
}
armTimer()
f.observeTunQueueLen(i)
case <-flushC:
flush("timer", true)
}
}
}
func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
if len(batch) == 0 {
return
}
f.flushDatagrams(index, writer, batch, reason)
for idx := range batch {
if batch[idx].packet != nil {
batch[idx].packet.Release()
batch[idx].packet = nil
}
}
if f.batchUDPFlushCounter != nil {
f.batchUDPFlushCounter.Inc(int64(len(batch)))
}
}
func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
if len(batch) == 0 {
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"writer": index,
"reason": reason,
"pending": len(batch),
}).Debug("udp batch flush summary")
}
maxSeg := f.effectiveGSOMaxSegments()
if bw, ok := writer.(udp.BatchConn); ok {
chunkCap := maxSeg
if chunkCap <= 0 {
chunkCap = len(batch)
}
chunk := make([]udp.Datagram, 0, chunkCap)
var (
currentAddr netip.AddrPort
segments int
)
flushChunk := func() {
if len(chunk) == 0 {
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"writer": index,
"segments": len(chunk),
"dest": chunk[0].Addr,
"reason": reason,
"pending_total": len(batch),
}).Debug("flushing UDP batch")
}
if err := bw.WriteBatch(chunk); err != nil {
f.l.WithError(err).
WithField("writer", index).
WithField("reason", reason).
Warn("Failed to write UDP batch")
}
chunk = chunk[:0]
segments = 0
}
for _, item := range batch {
if item.packet == nil || !item.addr.IsValid() {
continue
}
payload := item.packet.Payload()[:item.packet.Len]
if segments == 0 {
currentAddr = item.addr
}
if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) {
flushChunk()
currentAddr = item.addr
}
chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr})
segments++
}
flushChunk()
return
}
for _, item := range batch {
if item.packet == nil || !item.addr.IsValid() {
continue
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithFields(logrus.Fields{
"writer": index,
"reason": reason,
"dest": item.addr,
"segments": 1,
}).Debug("flushing UDP batch")
}
if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil {
f.l.WithError(err).
WithField("writer", index).
WithField("udpAddr", item.addr).
WithField("reason", reason).
Warn("Failed to write UDP packet")
}
}
}
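The chunking rule above groups consecutive datagrams that share a destination and caps each run at the GSO segment limit. The sketch below is illustrative only; it assumes the udp.Datagram type introduced later in this change and is not used anywhere in the diff.

// groupByDest splits a queued batch into runs of at most maxSeg datagrams per
// destination, which is the unit a GSO-capable writer can send in one syscall.
func groupByDest(batch []udp.Datagram, maxSeg int) [][]udp.Datagram {
	var runs [][]udp.Datagram
	var cur []udp.Datagram
	for _, d := range batch {
		if len(cur) > 0 && (d.Addr != cur[0].Addr || (maxSeg > 0 && len(cur) >= maxSeg)) {
			runs = append(runs, cur)
			cur = nil
		}
		cur = append(cur, d)
	}
	if len(cur) > 0 {
		runs = append(runs, cur)
	}
	return runs
}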
func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool {
if !addr.IsValid() || !f.batches.Enabled() {
return false
}
pkt := f.batches.newPacket()
if pkt == nil {
return false
}
payload := pkt.Payload()
if len(payload) < len(buf) {
pkt.Release()
return false
}
copy(payload, buf)
pkt.Len = len(buf)
if f.batches.enqueueTx(q, pkt, addr) {
f.observeUDPQueueLen(q)
return true
}
pkt.Release()
return false
}
func (f *Interface) writerForIndex(i int) udp.Conn {
if i < 0 || i >= len(f.writers) {
return nil
}
return f.writers[i]
}
func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) {
writer := f.writerForIndex(q)
if writer == nil {
f.l.WithField("udpAddr", addr).
WithField("writer", q).
Error("Failed to write outgoing packet: no writer available")
return
}
if err := writer.WriteTo(buf, addr); err != nil {
hostinfo.logger(f.l).
WithError(err).
WithField("udpAddr", addr).
Error("Failed to write outgoing packet")
}
}
func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool {
if pkt == nil || !addr.IsValid() || !f.batches.Enabled() {
return false
}
if f.batches.enqueueTx(q, pkt, addr) {
f.observeUDPQueueLen(q)
return true
}
return false
}
func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) {
if pkt == nil {
return
}
writer := f.writerForIndex(q)
if writer == nil {
f.l.WithField("udpAddr", addr).
WithField("writer", q).
Error("Failed to write outgoing packet: no writer available")
pkt.Release()
return
}
if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil {
hostinfo.logger(f.l).
WithError(err).
WithField("udpAddr", addr).
Error("Failed to write outgoing packet")
}
pkt.Release()
}
func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) {
if pkt == nil {
return
}
writer := f.readers[q]
if writer == nil {
pkt.Release()
return
}
if bw, ok := writer.(interface {
WriteBatch([]*overlay.Packet) (int, error)
}); ok {
if _, err := bw.WriteBatch([]*overlay.Packet{pkt}); err != nil {
f.l.WithError(err).WithField("queue", q).Warn("Failed to write tun packet via batch writer")
pkt.Release()
}
return
}
if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil {
f.l.WithError(err).Error("Failed to write to tun")
}
pkt.Release()
}
func (f *Interface) clonePacketWithHeadroom(pkt *overlay.Packet, required int) *overlay.Packet {
if pkt == nil {
return nil
}
payload := pkt.Payload()[:pkt.Len]
if len(payload) == 0 && required <= 0 {
return pkt
}
pool := f.batches.Pool()
if pool != nil {
if clone := pool.Get(); clone != nil {
if len(clone.Payload()) >= len(payload) {
clone.Len = copy(clone.Payload(), payload)
pkt.Release()
return clone
}
clone.Release()
}
}
if required < 0 {
required = 0
}
buf := make([]byte, required+len(payload))
n := copy(buf[required:], payload)
pkt.Release()
return &overlay.Packet{
Buf: buf,
Offset: required,
Len: n,
}
}
func (f *Interface) observeUDPQueueLen(i int) {
if f.batchUDPQueueGauge == nil {
return
}
f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i)))
}
func (f *Interface) observeTunQueueLen(i int) {
if f.batchTunQueueGauge == nil {
return
}
f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i)))
}
func (f *Interface) currentBatchFlushInterval() time.Duration {
if v := f.batchFlushInterval.Load(); v > 0 {
return time.Duration(v)
}
return 0
}
func (f *Interface) ensurePacketHeadroom(pkt **overlay.Packet, required int, queue int, reason string) bool {
p := *pkt
if p == nil {
return false
}
if required <= 0 || p.Offset >= required {
return true
}
clone := f.clonePacketWithHeadroom(p, required)
if clone == nil {
f.l.WithFields(logrus.Fields{
"queue": queue,
"reason": reason,
}).Warn("dropping packet lacking tun headroom")
return false
}
*pkt = clone
return true
}
func isVirtioHeadroomError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "headroom") || strings.Contains(msg, "virtio")
}
func (f *Interface) effectiveGSOMaxSegments() int {
max := f.gsoMaxSegments
if max <= 0 {
max = defaultGSOMaxSegments
}
if max > maxKernelGSOSegments {
max = maxKernelGSOSegments
}
if !f.enableGSO {
return 1
}
return max
}
type udpOffloadConfigurator interface {
ConfigureOffload(enableGSO, enableGRO bool, maxSegments int)
}
func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) {
if maxSegments <= 0 {
maxSegments = defaultGSOMaxSegments
}
if maxSegments > maxKernelGSOSegments {
maxSegments = maxKernelGSOSegments
}
f.enableGSO = enableGSO
f.enableGRO = enableGRO
f.gsoMaxSegments = maxSegments
for _, writer := range f.writers {
if cfg, ok := writer.(udpOffloadConfigurator); ok {
cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments)
}
}
}
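Writers opt into these runtime changes by implementing the optional udpOffloadConfigurator interface; anything else is simply skipped. A minimal hypothetical implementation, for illustration only and assuming the udp package is imported:

// offloadAwareConn is a hypothetical wrapper showing the shape of the optional
// interface; it just records the most recently requested offload settings.
type offloadAwareConn struct {
	udp.Conn
	gso, gro bool
	maxSeg   int
}

func (c *offloadAwareConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
	c.gso, c.gro, c.maxSeg = enableGSO, enableGRO, maxSegments
}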
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError) c.RegisterReloadCallback(f.reloadSendRecvError)
@@ -411,6 +1062,42 @@ func (f *Interface) reloadMisc(c *config.C) {
f.reQueryWait.Store(int64(n)) f.reQueryWait.Store(int64(n))
f.l.Info("timers.requery_wait_duration has changed") f.l.Info("timers.requery_wait_duration has changed")
} }
if c.HasChanged("listen.gso_flush_timeout") {
d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
if d < 0 {
d = 0
}
f.batchFlushInterval.Store(int64(d))
f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed")
} else if c.HasChanged("batch.flush_interval") {
d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval)
if d < 0 {
d = 0
}
f.batchFlushInterval.Store(int64(d))
f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout")
}
if c.HasChanged("batch.queue_depth") {
n := c.GetInt("batch.queue_depth", f.batchQueueDepth)
if n != f.batchQueueDepth {
f.batchQueueDepth = n
f.l.Warn("batch.queue_depth changes require a restart to take effect")
}
}
if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") {
enableGSO := c.GetBool("listen.enable_gso", f.enableGSO)
enableGRO := c.GetBool("listen.enable_gro", f.enableGRO)
maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments)
f.applyOffloadConfig(enableGSO, enableGRO, maxSeg)
f.l.WithFields(logrus.Fields{
"enableGSO": enableGSO,
"enableGRO": enableGRO,
"gsoMaxSegments": maxSeg,
}).Info("listen GSO/GRO configuration updated")
}
} }
func (f *Interface) emitStats(ctx context.Context, i time.Duration) { func (f *Interface) emitStats(ctx context.Context, i time.Duration) {

View File

@@ -360,8 +360,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
} }
out[i] = addr out[i] = addr
} }
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
} }
if !lh.myVpnNetworksTable.Contains(vpnAddr) { if !lh.myVpnNetworksTable.Contains(vpnAddr) {
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
} }
vals, ok := v.([]any) vals, ok := v.([]any)
@@ -1339,19 +1337,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
} }
} }
for _, a := range n.Details.V4AddrPorts {
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
}
for _, a := range n.Details.V6AddrPorts {
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
}
// This sends a nebula test packet to the host trying to contact us. In the case // This sends a nebula test packet to the host trying to contact us. In the case

View File

@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {

main.go
View File

@@ -5,8 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime/debug" "runtime"
"strings"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -15,7 +14,7 @@ import (
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any
@@ -29,10 +28,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
}() }()
if buildVersion == "" {
buildVersion = moduleVersion()
}
l := logger l := logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
@@ -81,8 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c) sshStart, err = configSSH(l, ssh, c)
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
sshStart = nil
} }
} }
@@ -150,6 +144,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
// set up our UDP listener // set up our UDP listener
udpConns := make([]udp.Conn, routines) udpConns := make([]udp.Conn, routines)
port := c.GetInt("listen.port", 0) port := c.GetInt("listen.port", 0)
enableGSO := c.GetBool("listen.enable_gso", true)
enableGRO := c.GetBool("listen.enable_gro", true)
gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
if gsoMaxSegments <= 0 {
gsoMaxSegments = defaultGSOMaxSegments
}
if gsoMaxSegments > maxKernelGSOSegments {
gsoMaxSegments = maxKernelGSOSegments
}
gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
if gsoFlushTimeout < 0 {
gsoFlushTimeout = 0
}
batchQueueDepth := c.GetInt("batch.queue_depth", 0)
if !configTest { if !configTest {
rawListenHost := c.GetString("listen.host", "0.0.0.0") rawListenHost := c.GetString("listen.host", "0.0.0.0")
@@ -169,13 +177,27 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
listenHost = ips[0].Unmap() listenHost = ips[0].Unmap()
} }
useWGDefault := runtime.GOOS == "linux"
useWG := c.GetBool("listen.use_wireguard_stack", useWGDefault)
var mkListener func(*logrus.Logger, netip.Addr, int, bool, int) (udp.Conn, error)
if useWG {
mkListener = udp.NewWireguardListener
} else {
mkListener = udp.NewListener
}
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
udpServer.ReloadConfig(c) udpServer.ReloadConfig(c)
if cfg, ok := udpServer.(interface {
ConfigureOffload(bool, bool, int)
}); ok {
cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
}
udpConns[i] = udpServer udpConns[i] = udpServer
// If port is dynamic, discover it before the next pass through the for loop // If port is dynamic, discover it before the next pass through the for loop
@@ -243,12 +265,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false), DropMulticast: c.GetBool("tun.drop_multicast", false),
EnableGSO: enableGSO,
EnableGRO: enableGRO,
GSOMaxSegments: gsoMaxSegments,
routines: routines, routines: routines,
MessageMetrics: messageMetrics, MessageMetrics: messageMetrics,
version: buildVersion, version: buildVersion,
relayManager: NewRelayManager(ctx, l, hostMap, c), relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy, punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
BatchFlushInterval: gsoFlushTimeout,
BatchQueueDepth: batchQueueDepth,
l: l, l: l,
} }
@@ -260,6 +287,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
ifce.writers = udpConns ifce.writers = udpConns
ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
lightHouse.ifce = ifce lightHouse.ifce = ifce
ifce.RegisterConfigChangeCallbacks(c) ifce.RegisterConfigChangeCallbacks(c)
@@ -302,18 +330,3 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
connManager.Start, connManager.Start,
}, nil }, nil
} }
func moduleVersion() string {
info, ok := debug.ReadBuildInfo()
if !ok {
return ""
}
for _, dep := range info.Deps {
if dep.Path == "github.com/slackhq/nebula" {
return strings.TrimPrefix(dep.Version, "v")
}
}
return ""
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -19,7 +20,7 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -61,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, ip, h.RemoteIndex) {
return return
} }
case header.MessageRelay: case header.MessageRelay:
@@ -465,23 +466,45 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
var (
err error
pkt *overlay.Packet
)
if f.batches.tunQueue(q) != nil {
pkt = f.batches.newPacket()
if pkt != nil {
out = pkt.Payload()[:0]
}
}
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
if addr.IsValid() {
f.maybeSendRecvError(addr, recvIndex)
}
return false return false
} }
err = newPacket(out, true, fwPacket) err = newPacket(out, true, fwPacket)
if err != nil { if err != nil {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithError(err).WithField("packet", out). hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet") Warnf("Error while validating inbound packet")
return false return false
} }
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
if pkt != nil {
pkt.Release()
}
hostinfo.logger(f.l).WithField("fwPacket", fwPacket). hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet") Debugln("dropping out of window packet")
return false return false
@@ -489,6 +512,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil { if dropReason != nil {
if pkt != nil {
pkt.Release()
}
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in // This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
@@ -501,8 +527,17 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
} }
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
if pkt != nil {
pkt.Len = len(out)
if f.batches.enqueueTun(q, pkt) {
f.observeTunQueueLen(q)
return true
}
f.writePacketToTun(q, pkt)
return true
}
if _, err = f.readers[q].Write(out); err != nil {
f.l.WithError(err).Error("Failed to write to tun") f.l.WithError(err).Error("Failed to write to tun")
} }
return true return true

View File

@@ -3,6 +3,7 @@ package overlay
import ( import (
"io" "io"
"net/netip" "net/netip"
"sync"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -13,6 +14,86 @@ type Device interface {
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
} }
// Packet represents a single packet buffer with optional headroom to carry
// metadata (for example virtio-net headers).
type Packet struct {
Buf []byte
Offset int
Len int
release func()
}
func (p *Packet) Payload() []byte {
return p.Buf[p.Offset : p.Offset+p.Len]
}
func (p *Packet) Reset() {
p.Len = 0
p.Offset = 0
p.release = nil
}
func (p *Packet) Release() {
if p.release != nil {
p.release()
p.release = nil
}
}
func (p *Packet) Capacity() int {
return len(p.Buf) - p.Offset
}
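The headroom arithmetic is easiest to see with a concrete buffer: Offset bytes are reserved in front of the payload (for a virtio-net header, for example), Payload() starts at Offset, and Capacity() is what remains for packet data. The following fragment is illustrative only, not part of this change:

// packetLayoutExample shows how Buf, Offset and Len relate.
func packetLayoutExample() {
	p := &Packet{Buf: make([]byte, 1514), Offset: 14} // 14 bytes of headroom
	p.Len = copy(p.Buf[p.Offset:], []byte("hello"))   // payload lands after the headroom
	_ = p.Payload()                                   // "hello"
	_ = p.Capacity()                                  // 1500 bytes usable after Offset
}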
// PacketPool manages reusable buffers with headroom.
type PacketPool struct {
headroom int
blksz int
pool sync.Pool
}
func NewPacketPool(headroom, payload int) *PacketPool {
p := &PacketPool{headroom: headroom, blksz: headroom + payload}
p.pool.New = func() any {
buf := make([]byte, p.blksz)
return &Packet{Buf: buf, Offset: headroom}
}
return p
}
func (p *PacketPool) Get() *Packet {
pkt := p.pool.Get().(*Packet)
pkt.Offset = p.headroom
pkt.Len = 0
pkt.release = func() { p.put(pkt) }
return pkt
}
func (p *PacketPool) put(pkt *Packet) {
pkt.Reset()
p.pool.Put(pkt)
}
// BatchReader allows reading multiple packets into a shared pool with
// preallocated headroom (e.g. virtio-net headers).
type BatchReader interface {
ReadIntoBatch(pool *PacketPool) ([]*Packet, error)
}
// BatchWriter writes a slice of packets that carry their own metadata.
type BatchWriter interface {
WriteBatch(packets []*Packet) (int, error)
}
// BatchCapableDevice describes a device that can efficiently read and write
// batches of packets with virtio headroom.
type BatchCapableDevice interface {
Device
BatchReader
BatchWriter
BatchHeadroom() int
BatchPayloadCap() int
BatchSize() int
}
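Putting the pieces together, a consumer of a BatchCapableDevice draws buffers from a shared pool, reads a burst, and hands the packets to WriteBatch, which in the implementations added by this change releases them. The loop below is a hedged sketch, not code used anywhere in the diff:

// echoBatch reads bursts from the device and writes them straight back out.
func echoBatch(dev BatchCapableDevice) error {
	pool := NewPacketPool(dev.BatchHeadroom(), dev.BatchPayloadCap())
	for {
		pkts, err := dev.ReadIntoBatch(pool)
		if err != nil {
			return err
		}
		if len(pkts) == 0 {
			continue
		}
		// WriteBatch takes ownership of the packets it accepts.
		if _, err := dev.WriteBatch(pkts); err != nil {
			return err
		}
	}
}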

View File

@@ -95,10 +95,6 @@ func (t *tun) Name() string {
return "android" return "android"
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android") return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
} }

View File

@@ -549,10 +549,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }

View File

@@ -105,10 +105,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil return t, nil
} }

View File

@@ -450,10 +450,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
} }

View File

@@ -151,10 +151,6 @@ func (t *tun) Name() string {
return "iOS" return "iOS"
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
} }

View File

@@ -9,6 +9,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"runtime"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -19,6 +20,7 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
wgtun "github.com/slackhq/nebula/wgstack/tun"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -33,6 +35,7 @@ type tun struct {
TXQueueLen int TXQueueLen int
deviceIndex int deviceIndex int
ioctlFd uintptr ioctlFd uintptr
wgDevice wgtun.Device
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -68,7 +71,9 @@ type ifreqQLEN struct {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
useWGDefault := runtime.GOOS == "linux"
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -113,7 +118,9 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
useWGDefault := runtime.GOOS == "linux"
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -123,16 +130,45 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil return t, nil
} }
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) {
var (
rw io.ReadWriteCloser = file
fd = int(file.Fd())
wgDev wgtun.Device
)
if useWireguard {
dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU))
if err != nil {
return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err)
}
wgDev = dev
rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU))
fd = int(dev.File().Fd())
}
t := &tun{ t := &tun{
ReadWriteCloser: rw,
fd: fd,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l, l: l,
} }
if wgDev != nil {
t.wgDevice = wgDev
}
if wgDev != nil {
// replace ioctl fd with device file descriptor to keep route management working
file = wgDev.File()
t.fd = int(file.Fd())
t.ioctlFd = file.Fd()
}
if t.ioctlFd == 0 {
t.ioctlFd = file.Fd()
}
err := t.reload(c, true) err := t.reload(c, true)
if err != nil { if err != nil {
@@ -216,10 +252,6 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
@@ -586,42 +618,48 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
} }
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device) link, err := netlink.LinkByName(t.Device)
if err != nil { if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways return gateways
} }
// If this route is relevant to our interface and there is a gateway then add it // If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index { if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := getGatewayAddr(r.Gw, r.Via) gwAddr, ok := netip.AddrFromSlice(r.Gw)
if ok { if !ok {
if t.isGatewayInVpnNetworks(gwAddr) { t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
} }
} }
for _, p := range r.MultiPath { for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it // If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index { if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := getGatewayAddr(p.Gw, p.Via) gwAddr, ok := netip.AddrFromSlice(p.Gw)
if ok { if !ok {
if t.isGatewayInVpnNetworks(gwAddr) { t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
} }
} }
} }
@@ -630,27 +668,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
return gateways return gateways
} }
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
// Try to use the old RTA_GATEWAY first
gwAddr, ok := netip.AddrFromSlice(gw)
if !ok {
// Fallback to the new RTA_VIA
rVia, ok := via.(*netlink.Via)
if ok {
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
}
}
if gwAddr.IsValid() {
gwAddr = gwAddr.Unmap()
return gwAddr, true
}
return netip.Addr{}, false
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) { func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route) gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 { if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required. // No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways") t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
@@ -693,6 +714,14 @@ func (t *tun) Close() error {
_ = t.ReadWriteCloser.Close() _ = t.ReadWriteCloser.Close()
} }
if t.wgDevice != nil {
_ = t.wgDevice.Close()
if t.ioctlFd > 0 {
// underlying fd already closed by the device
t.ioctlFd = 0
}
}
if t.ioctlFd > 0 { if t.ioctlFd > 0 {
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close() _ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
} }

View File

@@ -0,0 +1,56 @@
//go:build linux && !android && !e2e_testing
package overlay
import "fmt"
func (t *tun) batchIO() (*wireguardTunIO, bool) {
io, ok := t.ReadWriteCloser.(*wireguardTunIO)
return io, ok
}
func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
io, ok := t.batchIO()
if !ok {
return nil, fmt.Errorf("wireguard batch I/O not enabled")
}
return io.ReadIntoBatch(pool)
}
func (t *tun) WriteBatch(packets []*Packet) (int, error) {
io, ok := t.batchIO()
if ok {
return io.WriteBatch(packets)
}
for _, pkt := range packets {
if pkt == nil {
continue
}
if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
return 0, err
}
pkt.Release()
}
return len(packets), nil
}
func (t *tun) BatchHeadroom() int {
if io, ok := t.batchIO(); ok {
return io.BatchHeadroom()
}
return 0
}
func (t *tun) BatchPayloadCap() int {
if io, ok := t.batchIO(); ok {
return io.BatchPayloadCap()
}
return 0
}
func (t *tun) BatchSize() int {
if io, ok := t.batchIO(); ok {
return io.BatchSize()
}
return 1
}
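Callers don't need to know whether the wireguard-backed batch path is active; they can feature-detect it at runtime, which is how the interface code falls back when batch I/O is unavailable. An illustrative helper, not part of this change, assuming the io package is imported:

// writeViaBatchIfPossible prefers the batch writer when the device offers one
// and otherwise falls back to a plain Write of the payload bytes.
func writeViaBatchIfPossible(dev io.ReadWriteCloser, pkt *Packet) error {
	if bw, ok := dev.(BatchWriter); ok {
		_, err := bw.WriteBatch([]*Packet{pkt}) // takes ownership of pkt
		return err
	}
	_, err := dev.Write(pkt.Payload())
	pkt.Release()
	return err
}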

View File

@@ -390,10 +390,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
} }

View File

@@ -310,10 +310,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
} }

View File

@@ -132,10 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented") return nil, fmt.Errorf("TODO: multiqueue not implemented")
} }

View File

@@ -234,10 +234,6 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0) return t.tun.Write(b, 0)
} }
func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
} }

View File

@@ -46,10 +46,6 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)} return routing.Gateways{routing.NewGateway(ip, 1)}
} }
func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil return d, nil
} }

View File

@@ -0,0 +1,220 @@
//go:build linux && !android && !e2e_testing
package overlay
import (
"fmt"
"sync"
wgtun "github.com/slackhq/nebula/wgstack/tun"
)
type wireguardTunIO struct {
dev wgtun.Device
mtu int
batchSize int
readMu sync.Mutex
readBuffers [][]byte
readLens []int
legacyBuf []byte
writeMu sync.Mutex
writeBuf []byte
writeWrap [][]byte
writeBuffers [][]byte
}
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
batch := dev.BatchSize()
if batch <= 0 {
batch = 1
}
if mtu <= 0 {
mtu = DefaultMTU
}
return &wireguardTunIO{
dev: dev,
mtu: mtu,
batchSize: batch,
readLens: make([]int, batch),
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
writeWrap: make([][]byte, 1),
}
}
func (w *wireguardTunIO) Read(p []byte) (int, error) {
w.readMu.Lock()
defer w.readMu.Unlock()
bufs := w.readBuffers
if len(bufs) == 0 {
bufs = [][]byte{w.legacyBuf}
w.readBuffers = bufs
}
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, nil
}
length := w.readLens[0]
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
return length, nil
}
func (w *wireguardTunIO) Write(p []byte) (int, error) {
if len(p) > w.mtu {
return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
}
w.writeMu.Lock()
defer w.writeMu.Unlock()
buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
buf[i] = 0
}
copy(buf[wgtun.VirtioNetHdrLen:], p)
w.writeWrap[0] = buf
n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
if err != nil {
return n, err
}
return len(p), nil
}
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
if pool == nil {
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
}
w.readMu.Lock()
defer w.readMu.Unlock()
if len(w.readBuffers) < w.batchSize {
w.readBuffers = make([][]byte, w.batchSize)
}
if len(w.readLens) < w.batchSize {
w.readLens = make([]int, w.batchSize)
}
packets := make([]*Packet, w.batchSize)
requiredHeadroom := w.BatchHeadroom()
requiredPayload := w.BatchPayloadCap()
headroom := 0
for i := 0; i < w.batchSize; i++ {
pkt := pool.Get()
if pkt == nil {
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
}
if pkt.Capacity() < requiredPayload {
pkt.Release()
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
}
if i == 0 {
headroom = pkt.Offset
if headroom < requiredHeadroom {
pkt.Release()
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
}
} else if pkt.Offset != headroom {
pkt.Release()
releasePackets(packets[:i])
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
}
packets[i] = pkt
w.readBuffers[i] = pkt.Buf
}
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
if err != nil {
releasePackets(packets)
return nil, err
}
if n == 0 {
releasePackets(packets)
return nil, nil
}
for i := 0; i < n; i++ {
packets[i].Len = w.readLens[i]
}
for i := n; i < w.batchSize; i++ {
packets[i].Release()
packets[i] = nil
}
return packets[:n], nil
}
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
if len(packets) == 0 {
return 0, nil
}
requiredHeadroom := w.BatchHeadroom()
offset := packets[0].Offset
if offset < requiredHeadroom {
releasePackets(packets)
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
}
for _, pkt := range packets {
if pkt == nil {
continue
}
if pkt.Offset != offset {
releasePackets(packets)
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
}
limit := pkt.Offset + pkt.Len
if limit > len(pkt.Buf) {
releasePackets(packets)
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
}
}
w.writeMu.Lock()
defer w.writeMu.Unlock()
if len(w.writeBuffers) < len(packets) {
w.writeBuffers = make([][]byte, len(packets))
}
for i, pkt := range packets {
if pkt == nil {
w.writeBuffers[i] = nil
continue
}
limit := pkt.Offset + pkt.Len
w.writeBuffers[i] = pkt.Buf[:limit]
}
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
if err != nil {
return n, err
}
releasePackets(packets)
return n, nil
}
func (w *wireguardTunIO) BatchHeadroom() int {
return wgtun.VirtioNetHdrLen
}
func (w *wireguardTunIO) BatchPayloadCap() int {
return w.mtu
}
func (w *wireguardTunIO) BatchSize() int {
return w.batchSize
}
func (w *wireguardTunIO) Close() error {
return nil
}
func releasePackets(pkts []*Packet) {
for _, pkt := range pkts {
if pkt != nil {
pkt.Release()
}
}
}
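For reference, the framing that Write performs above can be pictured as prepending a zeroed virtio-net header (no offloads requested) in front of the IP packet. The helper below is illustrative only and assumes the wgstack/tun import already used in this file:

// frameForWireguard shows the on-device layout: VirtioNetHdrLen zero bytes of
// header followed by the payload, which is what the batched writes expect.
func frameForWireguard(payload []byte) []byte {
	buf := make([]byte, wgtun.VirtioNetHdrLen+len(payload))
	copy(buf[wgtun.VirtioNetHdrLen:], payload) // header bytes stay zero
	return buf
}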

pki.go
View File

@@ -523,13 +523,9 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
} }
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
l.WithField("fingerprint", fp).Info("Blocklisting cert")
caPool.BlocklistFingerprint(fp)
}
return caPool, nil return caPool, nil

View File

@@ -16,8 +16,8 @@ import (
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any

View File

@@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil return 0, nil
} }
func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported") return nil, errors.New("unsupported")
} }

View File

@@ -19,10 +19,21 @@ type Conn interface {
ListenOut(r EncReader) ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
Close() error Close() error
} }
// Datagram represents a UDP payload destined to a specific address.
type Datagram struct {
Payload []byte
Addr netip.AddrPort
}
// BatchConn can send multiple datagrams in one syscall.
type BatchConn interface {
Conn
WriteBatch(pkts []Datagram) error
}
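Callers can treat every Conn uniformly and only take the batched path when it is available, the same feature detection the interface code performs. A hypothetical sketch, not used by this change:

// sendAll prefers the single-syscall WriteBatch and degrades to per-datagram
// WriteTo calls when the Conn does not implement BatchConn.
func sendAll(c Conn, pkts []Datagram) error {
	if bc, ok := c.(BatchConn); ok {
		return bc.WriteBatch(pkts)
	}
	for _, d := range pkts {
		if err := c.WriteTo(d.Payload, d.Addr); err != nil {
			return err
		}
	}
	return nil
}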
type NoopConn struct{} type NoopConn struct{}
func (NoopConn) Rebind() error { func (NoopConn) Rebind() error {
@@ -34,9 +45,6 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) ListenOut(_ EncReader) {
return return
} }
func (NoopConn) SupportsMultipleReaders() bool {
return false
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }

View File

@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
return ErrInvalidIPv6RemoteForSocket return ErrInvalidIPv6RemoteForSocket
} }
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa) sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4 addrLen = syscall.SizeofSockaddrInet4
@@ -184,10 +184,6 @@ func (u *StdConn) ListenOut(r EncReader) {
} }
} }
func (u *StdConn) SupportsMultipleReaders() bool {
return false
}
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
var err error var err error
if u.isV4 { if u.isV4 {

View File

@@ -85,7 +85,3 @@ func (u *GenericConn) ListenOut(r EncReader) {
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
} }
} }
func (u *GenericConn) SupportsMultipleReaders() bool {
return false
}

View File

@@ -72,10 +72,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
} }
func (u *StdConn) SupportsMultipleReaders() bool {
return true
}
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
return nil return nil
} }
@@ -314,31 +310,51 @@ func (u *StdConn) Close() error {
} }
func NewUDPStatsEmitter(udpConns []Conn) func() { func NewUDPStatsEmitter(udpConns []Conn) func() {
if len(udpConns) == 0 {
return func() {}
}
type statsProvider struct {
index int
conn *StdConn
}
providers := make([]statsProvider, 0, len(udpConns))
for i, c := range udpConns {
if sc, ok := c.(*StdConn); ok {
providers = append(providers, statsProvider{index: i, conn: sc})
}
}
if len(providers) == 0 {
return func() {}
}
var meminfo [unix.SK_MEMINFO_VARS]uint32
if err := providers[0].conn.getMemInfo(&meminfo); err != nil {
return func() {}
}
udpGauges := make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(providers))
for i, provider := range providers {
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", provider.index), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", provider.index), nil),
}
}
return func() { return func() {
for i, gauges := range udpGauges { for i, provider := range providers {
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { if err := provider.conn.getMemInfo(&meminfo); err == nil {
for j := 0; j < unix.SK_MEMINFO_VARS; j++ { for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j])) udpGauges[i][j].Update(int64(meminfo[j]))
} }
} }
} }

View File

@@ -315,10 +315,6 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
} }
func (u *RIOConn) SupportsMultipleReaders() bool {
return false
}
func (u *RIOConn) Rebind() error { func (u *RIOConn) Rebind() error {
return nil return nil
} }

View File

@@ -127,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil return u.Addr, nil
} }
func (u *TesterConn) SupportsMultipleReaders() bool {
return false
}
func (u *TesterConn) Rebind() error { func (u *TesterConn) Rebind() error {
return nil return nil
} }

225
udp/wireguard_conn_linux.go Normal file
View File

@@ -0,0 +1,225 @@
//go:build linux && !android && !e2e_testing
package udp
import (
"errors"
"net"
"net/netip"
"sync"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
wgconn "github.com/slackhq/nebula/wgstack/conn"
)
// WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
type WGConn struct {
l *logrus.Logger
bind *wgconn.StdNetBind
recvers []wgconn.ReceiveFunc
batch int
reqBatch int
localIP netip.Addr
localPort uint16
enableGSO bool
enableGRO bool
gsoMaxSeg int
closed atomic.Bool
closeOnce sync.Once
}
// NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
bind := wgconn.NewStdNetBindForAddr(ip, multi)
recvers, actualPort, err := bind.Open(uint16(port))
if err != nil {
return nil, err
}
if batch <= 0 {
batch = bind.BatchSize()
} else if batch > bind.BatchSize() {
batch = bind.BatchSize()
}
return &WGConn{
l: l,
bind: bind,
recvers: recvers,
batch: batch,
reqBatch: batch,
localIP: ip,
localPort: actualPort,
}, nil
}
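Usage sketch (not in this diff; address, port, and batch size are placeholders): open the WireGuard-backed listener and start its receive goroutines. The callback signature mirrors how listen() invokes the EncReader below.
func exampleOpenWireguardListener() error { // hypothetical helper in package udp
    l := logrus.New()
    c, err := NewWireguardListener(l, netip.IPv4Unspecified(), 4242, false, 64)
    if err != nil {
        return err
    }
    c.ListenOut(func(from netip.AddrPort, payload []byte) {
        l.Debugf("received %d bytes from %s", len(payload), from)
    })
    return nil
}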
func (c *WGConn) Rebind() error {
// WireGuard's bind does not support rebinding in place.
return nil
}
func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
// Fallback to wildcard IPv4 for display purposes.
return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
}
return netip.AddrPortFrom(c.localIP, c.localPort), nil
}
func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
batchSize := c.batch
packets := make([][]byte, batchSize)
for i := range packets {
packets[i] = make([]byte, MTU)
}
sizes := make([]int, batchSize)
endpoints := make([]wgconn.Endpoint, batchSize)
for {
if c.closed.Load() {
return
}
n, err := fn(packets, sizes, endpoints)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
if c.l != nil {
c.l.WithError(err).Debug("wireguard UDP listener receive error")
}
continue
}
for i := 0; i < n; i++ {
if sizes[i] == 0 {
continue
}
stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
if !ok {
if c.l != nil {
c.l.Warn("wireguard UDP listener received unexpected endpoint type")
}
continue
}
addr := stdEp.AddrPort
r(addr, packets[i][:sizes[i]])
endpoints[i] = nil
}
}
}
func (c *WGConn) ListenOut(r EncReader) {
for _, fn := range c.recvers {
go c.listen(fn, r)
}
}
func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
if len(b) == 0 {
return nil
}
if c.closed.Load() {
return net.ErrClosed
}
ep := &wgconn.StdNetEndpoint{AddrPort: addr}
return c.bind.Send([][]byte{b}, ep)
}
func (c *WGConn) WriteBatch(datagrams []Datagram) error {
if len(datagrams) == 0 {
return nil
}
if c.closed.Load() {
return net.ErrClosed
}
max := c.batch
if max <= 0 {
max = len(datagrams)
if max == 0 {
max = 1
}
}
bufs := make([][]byte, 0, max)
var (
current netip.AddrPort
endpoint *wgconn.StdNetEndpoint
haveAddr bool
)
flush := func() error {
if len(bufs) == 0 || endpoint == nil {
bufs = bufs[:0]
return nil
}
err := c.bind.Send(bufs, endpoint)
bufs = bufs[:0]
return err
}
for _, d := range datagrams {
if len(d.Payload) == 0 || !d.Addr.IsValid() {
continue
}
if !haveAddr || d.Addr != current {
if err := flush(); err != nil {
return err
}
current = d.Addr
endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
haveAddr = true
}
bufs = append(bufs, d.Payload)
if len(bufs) >= max {
if err := flush(); err != nil {
return err
}
}
}
return flush()
}
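Illustrative only: WriteBatch flushes whenever the destination changes (or the batch limit is reached), so callers get the best coalescing by keeping datagrams for the same peer adjacent. Addresses and payloads below are placeholders.
func exampleWriteBatch(c *WGConn) error { // hypothetical helper
    peerA := netip.MustParseAddrPort("192.0.2.10:4242")
    peerB := netip.MustParseAddrPort("198.51.100.7:4242")
    return c.WriteBatch([]Datagram{
        {Payload: []byte("a1"), Addr: peerA},
        {Payload: []byte("a2"), Addr: peerA}, // coalesced with the previous entry into one Send
        {Payload: []byte("b1"), Addr: peerB}, // destination change forces a flush first
    })
}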
func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
c.enableGSO = enableGSO
c.enableGRO = enableGRO
if maxSegments <= 0 {
maxSegments = 1
} else if maxSegments > wgconn.IdealBatchSize {
maxSegments = wgconn.IdealBatchSize
}
c.gsoMaxSeg = maxSegments
effectiveBatch := c.reqBatch
if enableGSO && c.bind != nil {
bindBatch := c.bind.BatchSize()
if effectiveBatch < bindBatch {
if c.l != nil {
c.l.WithFields(logrus.Fields{
"requested": c.reqBatch,
"effective": bindBatch,
}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
}
effectiveBatch = bindBatch
}
}
c.batch = effectiveBatch
if c.l != nil {
c.l.WithFields(logrus.Fields{
"enableGSO": enableGSO,
"enableGRO": enableGRO,
"gsoMaxSegments": maxSegments,
}).Debug("configured wireguard UDP offload")
}
}
func (c *WGConn) ReloadConfig(*config.C) {
// WireGuard bind currently does not expose runtime configuration knobs.
}
func (c *WGConn) Close() error {
var err error
c.closeOnce.Do(func() {
c.closed.Store(true)
err = c.bind.Close()
})
return err
}

View File

@@ -0,0 +1,15 @@
//go:build !linux || android || e2e_testing
package udp
import (
"fmt"
"net/netip"
"github.com/sirupsen/logrus"
)
// NewWireguardListener is only available on Linux builds.
func NewWireguardListener(*logrus.Logger, netip.Addr, int, bool, int) (Conn, error) {
return nil, fmt.Errorf("wireguard experimental UDP listener is only supported on Linux")
}

539
wgstack/conn/bind_std.go Normal file
View File

@@ -0,0 +1,539 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"context"
"errors"
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
var (
_ Bind = (*StdNetBind)(nil)
)
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
// these three fields are not guarded by mu
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
blackhole4 bool
blackhole6 bool
listenAddr4 string
listenAddr6 string
bindV4 bool
bindV6 bool
reusePort bool
}
func newStdNetBind() *StdNetBind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
bindV4: true,
bindV6: true,
reusePort: false,
}
}
// NewStdNetBind creates a bind that listens on all interfaces.
func NewStdNetBind() *StdNetBind {
return newStdNetBind()
}
// NewStdNetBindForAddr creates a bind that listens on a specific address.
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
// IPv6 socket will be created.
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
b := newStdNetBind()
if addr.IsValid() {
if addr.IsUnspecified() {
// keep dual-stack defaults with empty listen addresses
} else if addr.Is4() {
b.listenAddr4 = addr.Unmap().String()
b.bindV4 = true
b.bindV6 = false
} else {
b.listenAddr6 = addr.Unmap().String()
b.bindV6 = true
b.bindV4 = false
}
}
b.reusePort = reusePort
return b
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if supported.
src struct {
netip.Addr
ifidx int32
}
}
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &StdNetEndpoint{
AddrPort: e,
}, nil
}
func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0
e.src.Addr = netip.Addr{}
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return e.src.Addr
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return e.src.ifidx
}
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func (e *StdNetEndpoint) SrcToString() string {
return e.src.Addr.String()
}
func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
lc := listenConfig()
if s.reusePort {
base := lc.Control
lc.Control = func(network, address string, c syscall.RawConn) error {
if base != nil {
if err := base(network, address, c); err != nil {
return err
}
}
return c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
}
}
addr := ":" + strconv.Itoa(port)
if host != "" {
addr = net.JoinHostPort(host, strconv.Itoa(port))
}
conn, err := lc.ListenPacket(context.Background(), network, addr)
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
}
func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
if !s.bindV4 {
return nil, nil, port, nil
}
host := s.listenAddr4
conn, actualPort, err := s.listenNet("udp4", host, port)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, nil, port, nil
}
return nil, nil, port, err
}
if runtime.GOOS != "linux" {
return conn, nil, actualPort, nil
}
pc := ipv4.NewPacketConn(conn)
return conn, pc, actualPort, nil
}
func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
if !s.bindV6 {
return nil, nil, port, nil
}
host := s.listenAddr6
conn, actualPort, err := s.listenNet("udp6", host, port)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, nil, port, nil
}
return nil, nil, port, err
}
if runtime.GOOS != "linux" {
return conn, nil, actualPort, nil
}
pc := ipv6.NewPacketConn(conn)
return conn, pc, actualPort, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
var err error
var tries int
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn *net.UDPConn
var v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
v4conn, v4pc, port, err = s.openIPv4(port)
if err != nil {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, v6pc, port, err = s.openIPv6(port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
if v4conn != nil {
v4conn.Close()
}
tries++
goto again
}
if err != nil {
if v4conn != nil {
v4conn.Close()
}
return nil, 0, err
}
var fns []ReceiveFunc
if v4conn != nil {
s.ipv4 = v4conn
if v4pc != nil {
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
}
if v6conn != nil {
s.ipv6 = v6conn
if v6pc != nil {
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" && pc != nil {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv6MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" && pc != nil {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" {
return IdealBatchSize
}
return 1
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
s.ipv4PC = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
s.ipv6PC = nil
}
s.blackhole4 = false
s.blackhole6 = false
if err1 != nil {
return err1
}
return err2
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
var (
pc4 *ipv4.PacketConn
pc6 *ipv6.PacketConn
)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
pc6 = s.ipv6PC
is6 = true
} else {
pc4 = s.ipv4PC
}
s.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
if is6 {
return s.send6(conn, pc6, endpoint, bufs)
} else {
return s.send4(conn, pc4, endpoint, bufs)
}
}
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
if runtime.GOOS == "linux" && pc != nil {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
for j := start; j < len(bufs); j++ {
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
if werr != nil {
err = werr
break
}
}
}
break
}
if n == len((*msgs)[start:len(bufs)]) {
break
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
if err != nil {
break
}
}
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
if runtime.GOOS == "linux" && pc != nil {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
for j := start; j < len(bufs); j++ {
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
if werr != nil {
err = werr
break
}
}
}
break
}
if n == len((*msgs)[start:len(bufs)]) {
break
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
if err != nil {
break
}
}
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
}

131
wgstack/conn/conn.go Normal file
View File

@@ -0,0 +1,131 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"errors"
"fmt"
"net/netip"
"reflect"
"runtime"
"strings"
)
const (
IdealBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation.
type Bind interface {
// Open puts the Bind into a listening state on a given port and reports the actual
// port that it bound to. Passing zero results in a random selection.
// fns is the set of functions that will be called to receive packets.
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener.
// All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
// Send writes one or more packets in bufs to address ep. The length of
// bufs must not exceed BatchSize().
Send(bufs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
}
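A minimal consumer sketch (an assumption, not part of the vendored sources) showing the contract above in use: the slices handed to a ReceiveFunc are sized to the bind's BatchSize, and zero-length entries are skipped.
func readLoop(b Bind, fn ReceiveFunc, mtu int) error {
    n := b.BatchSize()
    bufs := make([][]byte, n)
    for i := range bufs {
        bufs[i] = make([]byte, mtu)
    }
    sizes := make([]int, n)
    eps := make([]Endpoint, n)
    for {
        count, err := fn(bufs, sizes, eps)
        if err != nil {
            return err // fns return net.ErrClosed once the bind is closed
        }
        for i := 0; i < count; i++ {
            if sizes[i] == 0 {
                continue
            }
            _ = bufs[i][:sizes[i]] // hand packet i (from eps[i]) to the caller here
        }
    }
}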
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
}
// PeekLookAtSocketFd is implemented by Bind objects that support having their
// file descriptor peeked at. Used by wireguard-android.
type PeekLookAtSocketFd interface {
PeekLookAtSocketFd4() (fd int, err error)
PeekLookAtSocketFd6() (fd int, err error)
}
// An Endpoint maintains the source/destination caching for a peer.
//
// dst: the remote address of a peer ("endpoint" in uapi terminology)
// src: the local address from which datagrams originate going to the peer
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() netip.Addr
SrcIP() netip.Addr
}
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
func (fn ReceiveFunc) PrettyName() string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
}
for {
var idx int
for idx = len(name) - 1; idx >= 0; idx-- {
if name[idx] < '0' || name[idx] > '9' {
break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
}
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
name = name[idx+1:]
// 5. beansIPv6
}
if name == "" {
return fmt.Sprintf("%p", fn)
}
if strings.HasSuffix(name, "IPv4") {
return "v4"
}
if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
}

View File

@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"net"
"syscall"
)
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
// around this limitation)
const socketBufferSize = 7 << 20
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
}

View File

@@ -0,0 +1,62 @@
//go:build linux
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
// fail silently - the result of failure is lower performance on very fast
// links or high latency links.
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
// Set up to *mem_max
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
// Set beyond *mem_max if CAP_NET_ADMIN
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
})
},
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
if runtime.GOOS != "android" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
}
case "udp6":
c.Control(func(fd uintptr) {
if runtime.GOOS != "android" {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

9
wgstack/conn/default.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !windows
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
func NewDefaultBind() Bind { return NewStdNetBind() }

View File

@@ -0,0 +1,12 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(err error) bool {
return false
}

View File

@@ -0,0 +1,26 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func errShouldDisableUDPGSO(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
return serr.Err == unix.EIO
}
return false
}

View File

@@ -0,0 +1,15 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
return
}

View File

@@ -0,0 +1,29 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"golang.org/x/sys/unix"
)
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
rc, err := conn.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = errSyscall == nil
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
rxOffload = errSyscall == nil && opt == 1
})
if err != nil {
return false, false
}
return txOffload, rxOffload
}

View File

@@ -0,0 +1,21 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0

65
wgstack/conn/gso_linux.go Normal file
View File

@@ -0,0 +1,65 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
var gso uint16
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
existingLen := len(*control)
avail := cap(*control) - existingLen
space := unix.CmsgSpace(sizeOfGSOData)
if avail < space {
return
}
*control = (*control)[:cap(*control)]
gsoControl := (*control)[existingLen:]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
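A small sketch (assumption, same package) of how a sender could use setGSOSize: reserve gsoControlSize bytes of control space and attach a UDP_SEGMENT hint so the kernel splits one large write into smaller datagrams.
func exampleGSOControl() []byte { // hypothetical helper
    control := make([]byte, 0, gsoControlSize)
    setGSOSize(&control, 1350) // request 1350-byte segments; no-op if capacity is insufficient
    return control // pass as the OOB data of the sendmsg/WriteBatch carrying the super-packet
}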

64
wgstack/conn/mark_unix.go Normal file
View File

@@ -0,0 +1,64 @@
//go:build linux || openbsd || freebsd
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"runtime"
"golang.org/x/sys/unix"
)
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
if s.ipv4 != nil {
fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
if s.ipv6 != nil {
fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,42 @@
//go:build !linux || android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net/netip"
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return ""
}
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0
const StdNetSupportsStickySockets = false

View File

@@ -0,0 +1,116 @@
//go:build linux && !android
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
ep.src.ifidx = info.Ifindex
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
ep.src.Addr = netip.AddrFrom16(info.Addr)
ep.src.ifidx = int32(info.Ifindex)
return
}
}
}
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
// panics if buf is of insufficient size.
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
size := int(unsafe.Sizeof(t))
if len(buf) < size {
panic("pktInfoFromBuf: buffer too small")
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
return t
}
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = (*control)[:cap(*control)]
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
*control = (*control)[:0]
return
}
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
*control = (*control)[:0]
return
}
if len(*control) < srcControlSize {
*control = (*control)[:0]
return
}
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
if ep.SrcIP().Is4() {
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = ep.src.ifidx
if ep.SrcIP().IsValid() {
info.Spec_dst = ep.SrcIP().As4()
}
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
} else {
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = uint32(ep.src.ifidx)
if ep.SrcIP().IsValid() {
info.Addr = ep.SrcIP().As16()
}
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
}
}
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true

42
wgstack/tun/checksum.go Normal file
View File

@@ -0,0 +1,42 @@
package tun
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
i := 0
n := len(b)
for n >= 4 {
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
n -= 4
i += 4
}
for n >= 2 {
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
n -= 2
i += 2
}
if n == 1 {
ac += uint64(b[i]) << 8
}
return ac
}
func checksum(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
}
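A worked sketch (assumption, same package) of how these helpers compose for a transport checksum: the pseudo-header sum is passed as the initial value to checksum over the TCP header and payload, exactly as tcpTSO and coalesceTCPPackets do later in this change. The checksum field in pkt must be zeroed before calling it.
func tcpChecksumV4(pkt []byte, iphLen int) uint16 { // hypothetical helper
    src := pkt[12:16] // IPv4 source address (RFC 791 offsets)
    dst := pkt[16:20] // IPv4 destination address
    tcpLen := uint16(len(pkt) - iphLen)
    pseudo := pseudoHeaderChecksumNoFold(6 /* TCP */, src, dst, tcpLen)
    return ^checksum(pkt[iphLen:], pseudo)
}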

3
wgstack/tun/export.go Normal file
View File

@@ -0,0 +1,3 @@
package tun
const VirtioNetHdrLen = virtioNetHdrLen

View File

@@ -0,0 +1,630 @@
//go:build linux
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package tun
import (
"bytes"
"encoding/binary"
"errors"
"io"
"unsafe"
wgconn "github.com/slackhq/nebula/wgstack/conn"
"golang.org/x/sys/unix"
)
var ErrTooManySegments = errors.New("tun: too many segments for TSO")
const tcpFlagsOffset = 13
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
// flowKey represents the key for a flow.
type flowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
}
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
type tcpGROTable struct {
itemsByFlow map[flowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[flowKey][]tcpGROItem, wgconn.IdealBatchSize),
itemsPool: make([][]tcpGROItem, wgconn.IdealBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, wgconn.IdealBatchSize)
}
return t
}
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
key := flowKey{}
addrSize := dstAddr - srcAddr
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key flowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if pkt[0]>>4 == 6 {
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
// cannot coalesce with unequal Traffic class values
return coalesceUnavailable
}
if pkt[7] != pktTarget[7] {
// cannot coalesce with unequal Hop limit values
return coalesceUnavailable
}
} else {
if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values
return coalesceUnavailable
}
if pkt[6]>>5 != pktTarget[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return coalesceUnavailable
}
if pkt[8] != pktTarget[8] {
// cannot coalesce with unequal TTL values
return coalesceUnavailable
}
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
tcpTotalLen := uint16(len(pkt) - int(iphLen))
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = 0
coalescePSHEnding coalesceResult = 1
coalesceItemInvalidCSum coalesceResult = 2
coalescePktInvalidCSum coalesceResult = 3
coalesceSuccess coalesceResult = 4
)
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
// Flip the slice headers in bufs as part of prepend. The index of item
// is already being tracked for writing.
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
} else {
pktHead = bufs[item.bufsIndex][bufsOffset:]
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(headersLen),
gsoSize: uint16(item.gsoSize),
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
// (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
}
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
// Calculate the pseudo header checksum and place it at the TCP checksum
// offset. Downstream checksum offloading will combine this with computation
// of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := bufsOffset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments uint8 = 0x20
)
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
maxUint16 = 1<<16 - 1
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It will return false when pktI is not
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
// should be written to the Device.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return false
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return false
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return false
}
}
if len(pkt) < iphLen {
return false
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return false
}
if len(pkt) < iphLen+tcphLen {
return false
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return false
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return false
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return false
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return false
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
// more efficient if there are multiple items for a given flow. This
// also enables a natural table.deleteAt() in the
// coalesceItemInvalidCSum case without the need for index tracking.
// This algorithm makes a best effort to coalesce in the event of
// unordered packets, where pkt may land anywhere in items from a
// sequence number perspective, however once an item is inserted into
// the table it is never compared across other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return true
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return false
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return false
}
func isTCP4NoIPOptions(b []byte) bool {
if len(b) < 40 {
return false
}
if b[0]>>4 != 4 {
return false
}
if b[0]&0x0F != 5 {
return false
}
if b[9] != unix.IPPROTO_TCP {
return false
}
return true
}
func isTCP6NoEH(b []byte) bool {
if len(b) < 60 {
return false
}
if b[0]>>4 != 6 {
return false
}
if b[6] != unix.IPPROTO_TCP {
return false
}
return true
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var coalesced bool
switch {
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
}
if !coalesced {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
*toWrite = append(*toWrite, i)
}
}
return nil
}
// tcpTSO splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBuffs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(hdr.hdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^checksum(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// TCP header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// TCP checksum
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}
return i, nil
}
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
cSumAt := cSumStart + cSumOffset
// The initial value at the checksum offset should be summed with the
// checksum we compute. This is typically the pseudo-header checksum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
return nil
}

52
wgstack/tun/tun.go Normal file
View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package tun
import (
"os"
)
type Event int
const (
EventUp = 1 << iota
EventDown
EventMTUUpdate
)
type Device interface {
// File returns the file descriptor of the device.
File() *os.File
// Read one or more packets from the Device (without any additional headers).
// On a successful read it returns the number of packets read, and sets
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
// A nonzero offset can be used to instruct the Device on where to begin
// reading into each element of the bufs slice.
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
// Write one or more packets to the device (without any additional headers).
// On a successful write it returns the number of packets written. A nonzero
// offset can be used to instruct the Device on where to begin writing from
// each packet contained within the bufs slice.
Write(bufs [][]byte, offset int) (int, error)
// MTU returns the MTU of the Device.
MTU() (int, error)
// Name returns the current name of the Device.
Name() (string, error)
// Events returns a channel of type Event, which is fed Device events.
Events() <-chan Event
// Close stops the Device and closes the Event channel.
Close() error
// BatchSize returns the preferred/max number of packets that can be read or
// written in a single read/write call. BatchSize must not change over the
// lifetime of a Device.
BatchSize() int
}
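// readLoop is a hypothetical caller-side sketch, not part of this change,
// showing the batched Read contract above: len(sizes) must be at least
// len(bufs), and each packet is written starting at the requested offset.
// The headroom and buffer sizes are assumptions chosen for illustration.
func readLoop(dev Device) error {
	const offset = 16 // assumed headroom the caller wants in front of each packet
	bufs := make([][]byte, dev.BatchSize())
	sizes := make([]int, len(bufs))
	for i := range bufs {
		bufs[i] = make([]byte, offset+65535)
	}
	for {
		n, err := dev.Read(bufs, sizes, offset)
		if err != nil {
			return err
		}
		for i := 0; i < n; i++ {
			pkt := bufs[i][offset : offset+sizes[i]]
			_ = pkt // hand the packet to the caller's data plane
		}
	}
}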

664
wgstack/tun/tun_linux.go Normal file
View File

@@ -0,0 +1,664 @@
//go:build linux
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package tun
/* Implementation of the TUN device interface for linux
*/
import (
"errors"
"fmt"
"os"
"sync"
"syscall"
"time"
"unsafe"
wgconn "github.com/slackhq/nebula/wgstack/conn"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
const (
cloneDevicePath = "/dev/net/tun"
ifReqSize = unix.IFNAMSIZ + 64
)
type NativeTun struct {
tunFile *os.File
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
batchSize int
vnetHdr bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
readOpMu sync.Mutex // readOpMu guards readBuff
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
toWrite []int
tcp4GROTable, tcp6GROTable *tcpGROTable
}
func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
func (tun *NativeTun) routineHackListener() {
defer tun.hackListenerClosed.Unlock()
/* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch.
*/
last := 0
const (
up = 1
down = 2
)
for {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return
}
err2 := sysconn.Control(func(fd uintptr) {
_, err = unix.Write(int(fd), nil)
})
if err2 != nil {
return
}
switch err {
case unix.EINVAL:
if last != up {
// If the tunnel is up, it reports that write() is
// allowed but we provided invalid data.
tun.events <- EventUp
last = up
}
case unix.EIO:
if last != down {
// If the tunnel is down, it reports that no I/O
// is possible, without checking our provided data.
tun.events <- EventDown
last = down
}
default:
return
}
select {
case <-time.After(time.Second):
// nothing
case <-tun.statusListenersShutdown:
return
}
}
}
func createNetlinkSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {
return -1, err
}
return sock, nil
}
func (tun *NativeTun) routineNetlinkListener() {
defer func() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
tun.netlinkCancel.Close()
}()
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !tun.netlinkCancel.ReadyRead() {
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
return
}
}
if err != nil {
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
return
}
select {
case <-tun.statusListenersShutdown:
return
default:
}
wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if int(hdr.Len) > len(remain) {
break
}
switch hdr.Type {
case unix.NLMSG_DONE:
remain = []byte{}
case unix.RTM_NEWLINK:
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
remain = remain[hdr.Len:]
if info.Index != tun.index {
// not our interface
continue
}
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
// Don't emit EventDown before we've ever emitted EventUp.
// This avoids a startup race with HackListener, which
// might detect Up before we have finished reporting Down.
if wasEverUp {
tun.events <- EventDown
}
}
tun.events <- EventMTUUpdate
default:
remain = remain[hdr.Len:]
}
}
}
}
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFINDEX),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, errno
}
return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
}
func (tun *NativeTun) setMTU(n int) error {
name, err := tun.Name()
if err != nil {
return err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return errno
}
return nil
}
// routineNetlinkRead mirrors routineNetlinkListener but does not observe
// statusListenersShutdown; routineNetlink currently starts only
// routineNetlinkListener.
func (tun *NativeTun) routineNetlinkRead() {
defer func() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
tun.netlinkCancel.Close()
}()
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !tun.netlinkCancel.ReadyRead() {
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
return
}
}
if err != nil {
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
return
}
wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if int(hdr.Len) > len(remain) {
break
}
switch hdr.Type {
case unix.NLMSG_DONE:
remain = []byte{}
case unix.RTM_NEWLINK:
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
remain = remain[hdr.Len:]
if info.Index != tun.index {
continue
}
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
if wasEverUp {
tun.events <- EventDown
}
}
tun.events <- EventMTUUpdate
default:
remain = remain[hdr.Len:]
}
}
}
}
func (tun *NativeTun) routineNetlink() {
var err error
tun.netlinkSock, err = createNetlinkSocket()
if err != nil {
tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
return
}
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
if err != nil {
tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
return
}
go tun.routineNetlinkListener()
}
func (tun *NativeTun) Close() error {
var err1, err2 error
tun.closeOnce.Do(func() {
if tun.statusListenersShutdown != nil {
close(tun.statusListenersShutdown)
if tun.netlinkCancel != nil {
err1 = tun.netlinkCancel.Cancel()
}
} else if tun.events != nil {
close(tun.events)
}
err2 = tun.tunFile.Close()
})
if err1 != nil {
return err1
}
return err2
}
func (tun *NativeTun) BatchSize() int {
return tun.batchSize
}
const (
// TODO: support TSO with ECN bits
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
)
func (tun *NativeTun) initFromFlags(name string) error {
sc, err := tun.tunFile.SyscallConn()
if err != nil {
return err
}
if e := sc.Control(func(fd uintptr) {
var (
ifr *unix.Ifreq
)
ifr, err = unix.NewIfreq(name)
if err != nil {
return
}
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
if err != nil {
return
}
got := ifr.Uint16()
if got&unix.IFF_VNET_HDR != 0 {
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
if err != nil {
return
}
tun.vnetHdr = true
tun.batchSize = wgconn.IdealBatchSize
} else {
tun.batchSize = 1
}
}); e != nil {
return e
}
return err
}
// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
	nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
		}
		return nil, fmt.Errorf("CreateTUN(%q) failed to open %s: %w", name, cloneDevicePath, err)
	}
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
tun, err := CreateTUNFromFile(fd, mtu)
if err != nil {
return nil, err
}
if name != "tun" {
if err := tun.(*NativeTun).initFromFlags(name); err != nil {
tun.Close()
return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
}
}
return tun, nil
}
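// exampleCreateTUN is a hypothetical usage sketch, not part of this change:
// create a device, watch its event channel, and close it on shutdown. The
// interface name and MTU are assumed values for illustration.
func exampleCreateTUN() error {
	dev, err := CreateTUN("nebula1", 1300)
	if err != nil {
		return err
	}
	defer dev.Close()
	go func() {
		for ev := range dev.Events() {
			if ev&EventMTUUpdate != 0 {
				if mtu, err := dev.MTU(); err == nil {
					_ = mtu // react to the new MTU here
				}
			}
		}
	}()
	// ... dev.Read / dev.Write loops would run here ...
	return nil
}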
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
errors: make(chan error, 5),
events: make(chan Event, 5),
}
name, err := tun.Name()
if err != nil {
return nil, fmt.Errorf("failed to determine TUN name: %w", err)
}
if err := tun.initFromFlags(name); err != nil {
return nil, fmt.Errorf("failed to query TUN flags: %w", err)
}
if tun.batchSize == 0 {
tun.batchSize = 1
}
tun.index, err = getIFIndex(name)
if err != nil {
return nil, fmt.Errorf("failed to get TUN index: %w", err)
}
if err = tun.setMTU(mtu); err != nil {
return nil, fmt.Errorf("failed to set MTU: %w", err)
}
tun.statusListenersShutdown = make(chan struct{})
go tun.routineNetlink()
tun.tcp4GROTable = newTCPGROTable()
tun.tcp6GROTable = newTCPGROTable()
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
tun.nameOnce.Do(tun.initNameCache)
return tun.nameCache, tun.nameErr
}
func (tun *NativeTun) initNameCache() {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
tun.nameErr = err
return
}
err = sysconn.Control(func(fd uintptr) {
var ifr [ifReqSize]byte
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
fd,
uintptr(unix.TUNGETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
tun.nameErr = errno
return
}
tun.nameCache = unix.ByteSliceToString(ifr[:])
})
if err != nil && tun.nameErr == nil {
tun.nameErr = err
}
}
func (tun *NativeTun) MTU() (int, error) {
name, err := tun.Name()
if err != nil {
return 0, err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
var ifr [ifReqSize]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, errno
}
return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
}
func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.writeOpMu.Lock()
defer func() {
tun.tcp4GROTable.reset()
tun.tcp6GROTable.reset()
tun.writeOpMu.Unlock()
}()
var (
errs error
total int
)
tun.toWrite = tun.toWrite[:0]
if tun.vnetHdr {
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
if err != nil {
return 0, err
}
offset -= virtioNetHdrLen
} else {
for i := range bufs {
tun.toWrite = append(tun.toWrite, i)
}
}
for _, bufsI := range tun.toWrite {
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
if errors.Is(err, syscall.EBADFD) {
return total, os.ErrClosed
}
if err != nil {
errs = errors.Join(errs, err)
} else {
total += n
}
}
return total, errs
}
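// writeSketch is a hypothetical caller, not part of this change, illustrating
// the headroom contract of Write: when the device negotiated vnet headers,
// handleGRO prepends a virtio header in front of each packet, so the caller
// must leave at least virtioNetHdrLen spare bytes ahead of offset.
func writeSketch(dev Device, payload []byte) error {
	offset := virtioNetHdrLen // assumed minimum headroom for the virtio header
	buf := make([]byte, offset+len(payload))
	copy(buf[offset:], payload)
	_, err := dev.Write([][]byte{buf}, offset)
	return err
}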
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
var hdr virtioNetHdr
if err := hdr.decode(in); err != nil {
return 0, err
}
in = in[virtioNetHdrLen:]
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
return 0, err
}
}
if len(in) > len(bufs[0][offset:]) {
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
}
n := copy(bufs[0][offset:], in)
sizes[0] = n
return 1, nil
}
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
if len(in) <= int(hdr.csumStart+12) {
return 0, errors.New("packet is too short")
}
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.hdrLen = hdr.csumStart + tcpHLen
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
if hdr.hdrLen < hdr.csumStart {
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
cSumAt := int(hdr.csumStart + hdr.csumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
return tcpTSO(in, hdr, bufs, sizes, offset)
}
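// exampleGROHdr is a hypothetical example, not part of this change, of the
// virtio header handleVirtioRead expects for a coalesced IPv4/TCP read:
// csumStart is the IP header length, csumOffset points at the TCP checksum
// field within the TCP header, and gsoSize is the MSS each output segment
// may carry. All values are assumptions for illustration.
var exampleGROHdr = virtioNetHdr{
	flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
	gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
	hdrLen:     40,   // assumed 20-byte IPv4 header + 20-byte TCP header
	gsoSize:    1448, // assumed MSS
	csumStart:  20,   // bytes of IP header preceding the TCP header
	csumOffset: 16,   // offset of the checksum field within the TCP header
}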
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
tun.readOpMu.Lock()
defer tun.readOpMu.Unlock()
select {
case err := <-tun.errors:
return 0, err
default:
readInto := bufs[0][offset:]
if tun.vnetHdr {
readInto = tun.readBuff[:]
}
n, err := tun.tunFile.Read(readInto)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
if err != nil {
return 0, err
}
if tun.vnetHdr {
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
}
sizes[0] = n
return 1, nil
}
}