Compare commits

..

15 Commits

Author SHA1 Message Date
Jay Wren
be90e4aa05 handle virtio header in ctrl messages 2025-11-19 17:09:39 -05:00
Jay Wren
bc9711df68 batch more writes 2025-11-19 16:58:59 -05:00
Jay Wren
4e333c76ba write batching 2025-11-19 14:03:36 -05:00
Jay Wren
f29e21b411 don't register metrics in loops 2025-11-19 13:25:25 -05:00
Jay Wren
8b32382cd9 zero copy even with virtioheder 2025-11-19 12:03:38 -05:00
Jay Wren
518a78c9d2 preallocate nonce buffer 2025-11-18 14:19:05 -05:00
Jay Wren
7c3708561d instruments 2025-11-14 14:43:51 -05:00
Jay Wren
a62ffca975 fix 32bit 2025-11-13 15:10:51 -05:00
Jay Wren
226787ea1f prealloc them buffers 2025-11-11 15:20:50 -05:00
Jay Wren
b2bc6a09ca write in batches 2025-11-11 15:06:45 -05:00
Jay Wren
0f9b33aa36 reduce copying 2025-11-11 14:51:53 -05:00
Jay Wren
ef0a022375 more nonblocking 2025-11-11 14:22:40 -05:00
Jay Wren
b68e504865 hrm 2025-11-11 13:15:30 -05:00
Jay Wren
3344a840d1 just using the wg library works 2025-11-11 10:55:39 -05:00
Jay Wren
2bc9863e66 only wg tun, no batching 2025-11-10 16:54:00 -05:00
60 changed files with 1628 additions and 1381 deletions

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -10,7 +10,7 @@ jobs:
name: Build Linux/BSD All name: Build Linux/BSD All
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -24,7 +24,7 @@ jobs:
mv build/*.tar.gz release mv build/*.tar.gz release
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v4
with: with:
name: linux-latest name: linux-latest
path: release path: release
@@ -33,7 +33,7 @@ jobs:
name: Build Windows name: Build Windows
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -55,7 +55,7 @@ jobs:
mv dist\windows\wintun build\dist\windows\ mv dist\windows\wintun build\dist\windows\
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v4
with: with:
name: windows-latest name: windows-latest
path: build path: build
@@ -66,7 +66,7 @@ jobs:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -104,7 +104,7 @@ jobs:
fi fi
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v4
with: with:
name: darwin-latest name: darwin-latest
path: ./release/* path: ./release/*
@@ -124,11 +124,11 @@ jobs:
# be overwritten # be overwritten
- name: Checkout code - name: Checkout code
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/checkout@v5 uses: actions/checkout@v4
- name: Download artifacts - name: Download artifacts
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/download-artifact@v6 uses: actions/download-artifact@v4
with: with:
name: linux-latest name: linux-latest
path: artifacts path: artifacts
@@ -160,10 +160,10 @@ jobs:
needs: [build-linux, build-darwin, build-windows] needs: [build-linux, build-darwin, build-windows]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- name: Download artifacts - name: Download artifacts
uses: actions/download-artifact@v6 uses: actions/download-artifact@v4
with: with:
path: artifacts path: artifacts

View File

@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -32,7 +32,7 @@ jobs:
run: make vet run: make vet
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v9 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.5 version: v2.5
@@ -45,7 +45,7 @@ jobs:
- name: Build test mobile - name: Build test mobile
run: make build-test-mobile run: make build-test-mobile
- uses: actions/upload-artifact@v5 - uses: actions/upload-artifact@v4
with: with:
name: e2e packet flow linux-latest name: e2e packet flow linux-latest
path: e2e/mermaid/linux-latest path: e2e/mermaid/linux-latest
@@ -56,7 +56,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -77,7 +77,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -98,7 +98,7 @@ jobs:
os: [windows-latest, macos-latest] os: [windows-latest, macos-latest]
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -115,7 +115,7 @@ jobs:
run: make vet run: make vet
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v9 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.5 version: v2.5
@@ -125,7 +125,7 @@ jobs:
- name: End 2 end - name: End 2 end
run: make e2evv run: make e2evv
- uses: actions/upload-artifact@v5 - uses: actions/upload-artifact@v4
with: with:
name: e2e packet flow ${{ matrix.os }} name: e2e packet flow ${{ matrix.os }}
path: e2e/mermaid/${{ matrix.os }} path: e2e/mermaid/${{ matrix.os }}

107
bits.go
View File

@@ -9,13 +9,14 @@ type Bits struct {
length uint64 length uint64
current uint64 current uint64
bits []bool bits []bool
firstSeen bool
lostCounter metrics.Counter lostCounter metrics.Counter
dupeCounter metrics.Counter dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter outOfWindowCounter metrics.Counter
} }
func NewBits(bits uint64) *Bits { func NewBits(bits uint64) *Bits {
b := &Bits{ return &Bits{
length: bits, length: bits,
bits: make([]bool, bits, bits), bits: make([]bool, bits, bits),
current: 0, current: 0,
@@ -23,37 +24,34 @@ func NewBits(bits uint64) *Bits {
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil), outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
} }
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
return b
} }
func (b *Bits) Check(l *logrus.Logger, i uint64) bool { func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current { if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true return true
} }
// If i is within the window, check if it's been set already. // If i is within the window, check if it's been set already. The first window will fail this check
if i > b.current-b.length || i < b.length && b.current < b.length { if i > b.current-b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
return !b.bits[i%b.length] return !b.bits[i%b.length]
} }
// Not within the window // Not within the window
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i) l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
}
return false return false
} }
func (b *Bits) Update(l *logrus.Logger, i uint64) bool { func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current. // If i is the next number, return true and update current.
if i == b.current+1 { if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter // Report missed packets, we can only understand what was missed after the first window has been gone through
// The very first window can only be tracked as lost once we are on the 2nd window or greater if i > b.length && b.bits[i%b.length] == false {
if b.bits[i%b.length] == false && i > b.length {
b.lostCounter.Inc(1) b.lostCounter.Inc(1)
} }
b.bits[i%b.length] = true b.bits[i%b.length] = true
@@ -61,32 +59,61 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return true return true
} }
// If i is a jump, adjust the window, record lost, update current, and return true // If i packet is greater than current but less than the maximum length of our bitmap,
if i > b.current { // flip everything in between to false and move ahead.
lost := int64(0) if i > b.current && i < b.current+b.length {
// Zero out the bits between the current and the new counter value, limited by the window size, // In between current and i need to be zero'd to allow those packets to come in later
// since the window is shifting for n := b.current + 1; n < i; n++ {
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
}
b.bits[n%b.length] = false b.bits[n%b.length] = false
} }
// Only record any skipped packets as a result of the window moving further than the window length b.bits[i%b.length] = true
// Any loss within the new window will be accounted for in future calls b.current = i
lost += max(0, int64(i-b.current-b.length)) //l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
b.lostCounter.Inc(lost) b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true b.bits[i%b.length] = true
b.current = i b.current = i
return true return true
} }
// If i is within the current window but below the current counter, // Allow for the 0 packet to come in within the first window
// Check to see if it's a duplicate if i == 0 && b.firstSeen == false && b.current < b.length {
if i > b.current-b.length || i < b.length && b.current < b.length { b.firstSeen = true
if b.current == i || b.bits[i%b.length] == true { b.bits[i%b.length] = true
return true
}
// If i is within the window of current minus length (the total pat window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window") Debug("Receive window")
@@ -95,8 +122,18 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false return false
} }
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true b.bits[i%b.length] = true
return true return true
} }
// In all other cases, fail and don't change current. // In all other cases, fail and don't change current.
@@ -110,3 +147,11 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
} }
return false return false
} }
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@@ -15,41 +15,48 @@ func TestBits(t *testing.T) {
assert.Len(t, b.bits, 10) assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work. // This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1)) assert.True(t, b.Check(l, 1))
assert.True(t, b.Update(l, 1)) u := b.Update(l, 1)
assert.True(t, u)
assert.EqualValues(t, 1, b.current) assert.EqualValues(t, 1, b.current)
g := []bool{true, true, false, false, false, false, false, false, false, false} g := []bool{false, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two // Receive two
assert.True(t, b.Check(l, 2)) assert.True(t, b.Check(l, 2))
assert.True(t, b.Update(l, 2)) u = b.Update(l, 2)
assert.True(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
g = []bool{true, true, true, false, false, false, false, false, false, false} g = []bool{false, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two again - it will fail // Receive two again - it will fail
assert.False(t, b.Check(l, 2)) assert.False(t, b.Check(l, 2))
assert.False(t, b.Update(l, 2)) u = b.Update(l, 2)
assert.False(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element // Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15)) assert.True(t, b.Check(l, 15))
assert.True(t, b.Update(l, 15)) u = b.Update(l, 15)
assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false} g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window // Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14)) assert.True(t, b.Check(l, 14))
assert.True(t, b.Update(l, 14)) u = b.Update(l, 14)
assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window // Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(l, 5)) assert.False(t, b.Check(l, 5))
assert.False(t, b.Update(l, 5)) u = b.Update(l, 5)
assert.False(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
@@ -62,29 +69,10 @@ func TestBits(t *testing.T) {
// Walk through a few windows in order // Walk through a few windows in order
b = NewBits(10) b = NewBits(10)
for i := uint64(1); i <= 100; i++ { for i := uint64(0); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i)
} }
assert.False(t, b.Check(l, 1), "Out of window check")
}
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
} }
func TestBitsDupeCounter(t *testing.T) { func TestBitsDupeCounter(t *testing.T) {
@@ -136,7 +124,8 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
assert.False(t, b.Update(l, 0)) assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost //tODO: make sure lostcounter doesn't increase in orderly increment
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
} }
@@ -148,6 +137,8 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20)) assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21)) assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22)) assert.True(t, b.Update(l, 22))
@@ -158,7 +149,7 @@ func TestBitsLostCounter(t *testing.T) {
assert.True(t, b.Update(l, 27)) assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28)) assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29)) assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -167,6 +158,8 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9)) assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets // 10 will set 0 index, 0 was already set, no lost packets
@@ -221,62 +214,6 @@ func TestBitsLostCounter(t *testing.T) {
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
} }
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) { func BenchmarkBits(b *testing.B) {
z := NewBits(10) z := NewBits(10)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {

View File

@@ -173,8 +173,6 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var passphrase []byte var passphrase []byte
if !isP11 && *cf.encryption { if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
@@ -194,7 +192,6 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
} }
} }
}
var curve cert.Curve var curve cert.Curve
var pub, rawPriv []byte var pub, rawPriv []byte

View File

@@ -171,17 +171,6 @@ func Test_ca(t *testing.T) {
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, ca(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// read encrypted key file and verify default params // read encrypted key file and verify default params
rb, _ = os.ReadFile(keyF.Name()) rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb) k, _ := pem.Decode(rb)

View File

@@ -5,28 +5,10 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime/debug"
"strings"
) )
// A version string that can be set with
//
// -ldflags "-X main.Build=SOMEVERSION"
//
// at compile-time.
var Build string var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
type helpError struct { type helpError struct {
s string s string
} }

View File

@@ -43,7 +43,7 @@ type signFlags struct {
func newSignFlags() *signFlags { func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {} sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA") sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -116,10 +116,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted // naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if errors.Is(err, cert.ErrPrivateKeyEncrypted) { if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
var passphrase []byte
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
// ask for a passphrase until we get one // ask for a passphrase until we get one
var passphrase []byte
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
@@ -137,7 +135,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if len(passphrase) == 0 { if len(passphrase) == 0 {
return fmt.Errorf("cannot open encrypted ca-key without passphrase") return fmt.Errorf("cannot open encrypted ca-key without passphrase")
} }
}
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
if err != nil { if err != nil {
return fmt.Errorf("error while parsing encrypted ca-key: %s", err) return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
@@ -167,10 +165,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("ca certificate is expired") return fmt.Errorf("ca certificate is expired")
} }
if version == 0 {
version = caCert.Version()
}
// if no duration is given, expire one second before the root expires // if no duration is given, expire one second before the root expires
if *sf.duration <= 0 { if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -283,19 +277,21 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
notBefore := time.Now() notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration) notAfter := notBefore.Add(*sf.duration)
switch version { if version == 0 || version == cert.Version1 {
case cert.Version1: // Make sure we at least have an ip
// Make sure we have only one ipv4 address
if len(v4Networks) != 1 { if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
} }
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 { if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses") return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
} }
if len(v6UnsafeNetworks) > 0 { if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
} }
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
@@ -325,8 +321,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
crts = append(crts, nc) crts = append(crts, nc)
}
case cert.Version2: if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
Version: cert.Version2, Version: cert.Version2,
Name: *sf.name, Name: *sf.name,
@@ -354,9 +351,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
crts = append(crts, nc) crts = append(crts, nc)
default:
// this should be unreachable
return fmt.Errorf("invalid version: %d", version)
} }
if !isP11 && *sf.inPubPath == "" { if !isP11 && *sf.inPubPath == "" {

View File

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
" -unsafe-networks string\n"+ " -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+ " -version uint\n"+
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n", " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(), ob.String(),
) )
} }
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -379,15 +379,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the wrong password // test with the wrong password
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
@@ -398,17 +389,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test with the wrong password in environment
ob.Reset()
eb.Reset()
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the user not entering a password // test with the user not entering a password
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()

View File

@@ -4,8 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@@ -20,17 +18,6 @@ import (
// at compile-time. // at compile-time.
var Build string var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() { func main() {
serviceFlag := flag.String("service", "", "Control the system service.") serviceFlag := flag.String("service", "", "Control the system service.")
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")

View File

@@ -4,8 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@@ -20,17 +18,6 @@ import (
// at compile-time. // at compile-time.
var Build string var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() { func main() {
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config") configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")

View File

@@ -17,7 +17,7 @@ import (
"dario.cat/mergo" "dario.cat/mergo"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
type C struct { type C struct {

View File

@@ -10,7 +10,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {

View File

@@ -50,6 +50,11 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
} }
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{ hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs, CipherSuite: ncs,
Random: rand.Reader, Random: rand.Reader,
@@ -69,7 +74,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
ci := &ConnectionState{ ci := &ConnectionState{
H: hs, H: hs,
initiator: initiator, initiator: initiator,
window: NewBits(ReplayWindow), window: b,
myCert: crt, myCert: crt,
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.

View File

@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap return c.f.hostMap
} }
func (c *Control) GetF() *Interface {
return c.f
}
func (c *Control) GetCertState() *CertState { func (c *Control) GetCertState() *CertState {
return c.f.pki.getCertState() return c.f.pki.getCertState()
} }

View File

@@ -20,7 +20,7 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
@@ -97,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Stop() theirControl.Stop()
} }
func TestGoodHandshakeNoOverlap(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
empty := []byte{}
t.Log("do something to cause a handshake")
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
myControl.WaitForType(header.Test, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestWrongResponderHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -499,35 +464,6 @@ func TestRelays(t *testing.T) {
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
} }
func TestRelaysDontCareAboutIps(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestReestablishRelays(t *testing.T) { func TestReestablishRelays(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -1291,109 +1227,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[0].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
myCfg := m{
"tun": m{
"unsafe_routes": []m{route},
},
}
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
spookyDest := netip.MustParseAddr("192.168.6.4")
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -22,14 +22,15 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "gopkg.in/yaml.v3"
"go.yaml.in/yaml/v3"
) )
type m = map[string]any type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") { for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
budpIp[3] = 239 budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
} }
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides) _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
}
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
}
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
firewallInbound := []m{{
"proto": "any",
"port": "any",
"host": "any",
}}
var unsafeNetworks []netip.Prefix
if sUnsafeNetworks != "" {
firewallInbound = []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "0.0.0.0/0",
}}
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
unsafeNetworks = append(unsafeNetworks, x)
}
}
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": "any", "port": "any",
"host": "any", "host": "any",
}}, }},
"inbound": firewallInbound, "inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
}, },
//"handshakes": m{ //"handshakes": m{
// "try_interval": "1s", // "try_interval": "1s",
@@ -308,10 +266,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
// Get both host infos // Get both host infos
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")

View File

@@ -318,50 +318,3 @@ func TestCertMismatchCorrection(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}

View File

@@ -383,9 +383,8 @@ firewall:
# host: `any` or a literal hostname, ie `test-host` # host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case. # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# ca_name: An issuing CA name # ca_name: An issuing CA name

View File

@@ -8,7 +8,6 @@ import (
"hash/fnv" "hash/fnv"
"net/netip" "net/netip"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -23,7 +22,7 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
} }
type conn struct { type conn struct {
@@ -248,11 +247,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp.IsValid() {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -260,7 +270,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@@ -287,7 +297,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto) return fmt.Errorf("unknown protocol %v", proto)
} }
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
} }
// GetRuleHash returns a hash representation of all inbound and outbound rules // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -327,6 +337,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
} }
for i, t := range rs { for i, t := range rs {
var groups []string
r, err := convertRule(l, t, table, i) r, err := convertRule(l, t, table, i)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err) return fmt.Errorf("%s rule #%v; %s", table, i, err)
@@ -336,10 +347,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
} }
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
} }
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string var sPort, errPort string
if r.Code != "" { if r.Code != "" {
errPort = "code" errPort = "code"
@@ -368,25 +392,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
} }
if r.Cidr != "" && r.Cidr != "any" { var cidr netip.Prefix
_, err = netip.ParsePrefix(r.Cidr) if r.Cidr != "" {
cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
} }
} }
if r.LocalCidr != "" && r.LocalCidr != "any" { var localCidr netip.Prefix
_, err = netip.ParsePrefix(r.LocalCidr) if r.LocalCidr != "" {
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
} }
} }
if warning := r.sanity(); warning != nil { err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
l.Warnf("%s rule #%v; %s", table, i, warning)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err) return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
} }
@@ -395,10 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil return nil
} }
var ErrUnknownNetworkType = errors.New("unknown network type") var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrPeerRejected = errors.New("remote address is not within a network that we handle") var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table") var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It // Drop returns an error if the packet should be dropped, explaining why. It
@@ -409,31 +429,18 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
return nil return nil
} }
// Make sure remote address matches nebula certificate, and determine how to treat it // Make sure remote address matches nebula certificate
if h.networks == nil { if h.networks != nil {
// Simple case: Certificate has one address and no unsafe networks if !h.networks.Contains(fp.RemoteAddr) {
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} else { } else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr) // Simple case: Certificate has one address and no unsafe networks
if !ok { if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
@@ -633,7 +640,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
return false return false
} }
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port was lower than end port") return fmt.Errorf("start port was lower than end port")
} }
@@ -646,7 +653,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
} }
} }
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil { if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err return err
} }
} }
@@ -677,7 +684,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return fp[firewall.PortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error { func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule { fr := func() *FirewallRule {
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR), Hosts: make(map[string]*firewallLocalCIDR),
@@ -691,14 +698,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
fc.Any = fr() fc.Any = fr()
} }
return fc.Any.addRule(f, groups, host, cidr, localCidr) return fc.Any.addRule(f, groups, host, ip, localIp)
} }
if caSha != "" { if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok { if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr() fc.CAShas[caSha] = fr()
} }
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr) err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -708,7 +715,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
if _, ok := fc.CANames[caName]; !ok { if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr() fc.CANames[caName] = fr()
} }
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr) err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -740,24 +747,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
return fc.CANames[s.Certificate.Name()].match(p, c) return fc.CANames[s.Certificate.Name()].match(p, c)
} }
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error { func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR { flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{ return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite), LocalCIDR: new(bart.Lite),
} }
} }
if fr.isAny(groups, host, cidr) { if fr.isAny(groups, host, ip) {
if fr.Any == nil { if fr.Any == nil {
fr.Any = flc() fr.Any = flc()
} }
return fr.Any.addRule(f, localCidr) return fr.Any.addRule(f, localCIDR)
} }
if len(groups) > 0 { if len(groups) > 0 {
nlc := flc() nlc := flc()
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
@@ -773,34 +780,30 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localC
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.Hosts[host] = nlc fr.Hosts[host] = nlc
} }
if cidr != "" { if ip.IsValid() {
c, err := netip.ParsePrefix(cidr) nlc, _ := fr.CIDR.Get(ip)
if err != nil {
return err
}
nlc, _ := fr.CIDR.Get(c)
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err = nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.CIDR.Insert(c, nlc) fr.CIDR.Insert(ip, nlc)
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && cidr == "" { if len(groups) == 0 && host == "" && !ip.IsValid() {
return true return true
} }
@@ -814,7 +817,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
return true return true
} }
if cidr == "any" { if ip.IsValid() && ip.Bits() == 0 {
return true return true
} }
@@ -866,13 +869,8 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
return false return false
} }
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if localCidr == "any" { if !localIp.IsValid() {
flc.Any = true
return nil
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true flc.Any = true
return nil return nil
@@ -883,13 +881,12 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
} }
return nil return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
} }
c, err := netip.ParsePrefix(localCidr) flc.LocalCIDR.Insert(localIp)
if err != nil {
return err
}
flc.LocalCIDR.Insert(c)
return nil return nil
} }
@@ -910,6 +907,7 @@ type rule struct {
Code string Code string
Proto string Proto string
Host string Host string
Group string
Groups []string Groups []string
Cidr string Cidr string
LocalCidr string LocalCidr string
@@ -951,8 +949,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
m["group"] = v[0] m["group"] = v[0]
} }
r.Group = toString("group", m)
singleGroup := toString("group", m)
if rg, ok := m["groups"]; ok { if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() { switch reflect.TypeOf(rg).Kind() {
@@ -969,60 +966,9 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
} }
} }
//flatten group vs groups
if singleGroup != "" {
// Check if we have both groups and group provided in the rule config
if len(r.Groups) > 0 {
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
}
r.Groups = []string{singleGroup}
}
return r, nil return r, nil
} }
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
func (r *rule) sanity() error {
//port, proto, local_cidr are AND, no need to check here
//ca_sha and ca_name don't have a wildcard value, no need to check here
groupsEmpty := len(r.Groups) == 0
hostEmpty := r.Host == ""
cidrEmpty := r.Cidr == ""
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
return nil //no content!
}
groupsHasAny := slices.Contains(r.Groups, "any")
if groupsHasAny && len(r.Groups) > 1 {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
}
if r.Host == "any" {
if !groupsEmpty {
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
}
if !cidrEmpty {
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
}
}
if groupsHasAny {
if !hostEmpty && r.Host != "any" {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
}
if !cidrEmpty {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
}
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) { func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" { if s == "any" {
startPort = firewall.PortAny startPort = firewall.PortAny

View File

@@ -8,8 +8,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -73,114 +71,85 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128") ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0") anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -205,10 +174,10 @@ func TestFirewall_Drop(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -227,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -257,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -284,10 +250,10 @@ func TestFirewall_DropV6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -306,28 +272,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -338,12 +304,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
pfix := netip.MustParsePrefix("172.1.1.1/32") pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128") pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@@ -487,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -514,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.CachedCertificate{ c1 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -529,10 +493,10 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1, peerCert: &c1,
}, },
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@@ -546,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -579,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.CachedCertificate{ c2 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -594,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h2.buildNetworks(myVpnNetworksTable, c2.Certificate) h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.CachedCertificate{ c3 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -609,11 +571,11 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h3.buildNetworks(myVpnNetworksTable, c3.Certificate) h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@@ -627,7 +589,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match // Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
@@ -635,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
@@ -660,12 +620,12 @@ func TestFirewall_Drop3V6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match // Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool() cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -673,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -701,10 +659,10 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -717,7 +675,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -726,7 +684,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -738,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
c := cert.CachedCertificate{ c := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -761,11 +717,11 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -959,28 +915,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8") cidr := netip.MustParsePrefix("10.0.0.0/8")
@@ -988,14 +944,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6 // Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8") cidr6 := netip.MustParsePrefix("fd00::/8")
@@ -1003,75 +959,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6 // Test adding rule with local_cidr ipv6
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(l)
@@ -1094,7 +1024,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
// Ensure group array of > 1 is errord // Ensure group array of > 1 is errord
ob.Reset() ob.Reset()
@@ -1114,228 +1044,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"host": "bob"},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range noWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
}
yesWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
c["host"] = "any"
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
//reset the list
yesWarningPlease = []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
r.Groups = append(r.Groups, "any")
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
}
type testcase struct {
h *HostInfo
p firewall.Packet
c cert.Certificate
err error
}
func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c1,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
}
for i := range theirPrefixes {
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
return testcase{
h: &h,
p: p,
c: &c1,
err: err,
}
}
type testsetup struct {
c dummyCert
myVpnNetworksTable *bart.Lite
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
c: c,
fw: fw,
myVpnNetworksTable: myVpnNetworksTable,
}
}
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
tc.Test(t, setup.fw)
})
t.Run("allow inbound local matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
tc.Test(t, setup.fw)
})
t.Run("block inbound remote mismatched", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("Block a vpn peer packet", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
tc.Test(t, setup.fw)
})
twoPrefixes := []netip.Prefix{
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
}
t.Run("allow inbound one matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, twoPrefixes...)
tc.Test(t, setup.fw)
})
t.Run("block inbound multimismatch", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
t.Parallel()
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
tc := buildTestCase(setup2, nil, twoPrefixes...)
tc.p.RemoteAddr = twoPrefixes[1].Addr()
tc.Test(t, setup2.fw)
})
t.Run("allow inbound unsafe route", func(t *testing.T) {
t.Parallel()
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
unsafeNetworks: []netip.Prefix{unsafePrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
unsafeSetup := newSetupFromCert(t, l, c)
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
} }
type addRuleCall struct { type addRuleCall struct {
@@ -1345,8 +1054,8 @@ type addRuleCall struct {
endPort int32 endPort int32
groups []string groups []string
host string host string
ip string ip netip.Prefix
localIp string localIp netip.Prefix
caName string caName string
caSha string caSha string
} }
@@ -1356,7 +1065,7 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,

15
go.mod
View File

@@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0 github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.26.0 github.com/gaissmai/bart v0.25.0
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4 github.com/kardianos/service v1.2.4
@@ -22,17 +22,16 @@ require (
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.45.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.47.0 golang.org/x/net v0.45.0
golang.org/x/sync v0.18.0 golang.org/x/sync v0.17.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.37.0
golang.org/x/term v0.37.0 golang.org/x/term v0.36.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.10 google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
) )

30
go.sum
View File

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -155,15 +155,13 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +189,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -209,11 +207,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -244,8 +242,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -2,6 +2,7 @@ package nebula
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
@@ -191,17 +192,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version() certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
vpnNetworks := remoteCert.Certificate.Networks()
anyVpnAddrsInCommon := false for _, network := range remoteCert.Certificate.Networks() {
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddr := network.Addr()
for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(vpnAddr) {
if f.myVpnAddrsTable.Contains(network.Addr()) { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -209,10 +210,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return return
} }
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { // vpnAddrs outside our vpn networks are of no use to us, filter them out
anyVpnAddrsInCommon = true if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
} }
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
} }
if addr.IsValid() { if addr.IsValid() {
@@ -249,30 +264,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}, },
} }
msgRxL := f.l.WithFields(m{ f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
"vpnAddrs": vpnAddrs, WithField("certName", certName).
"udpAddr": addr, WithField("certVersion", certVersion).
"certName": certName, WithField("fingerprint", fingerprint).
"certVersion": certVersion, WithField("issuer", issuer).
"fingerprint": fingerprint, WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
"issuer": issuer, WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
"initiatorIndex": hs.Details.InitiatorIndex, Info("Handshake message received")
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil { if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available") Error("Unable to handshake with host because no certificate handshake bytes is available")
return return
} }
@@ -330,7 +341,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil { if err != nil {
@@ -571,22 +582,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
correctHostResponded := false var vpnAddrs []netip.Addr
anyVpnAddrsInCommon := false var filteredNetworks []netip.Prefix
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for _, network := range vpnNetworks {
for i, network := range vpnNetworks { // vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddrs[i] = network.Addr() vpnAddr := network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { if !f.myVpnNetworksTable.Contains(vpnAddr) {
anyVpnAddrsInCommon = true continue
} }
if hostinfo.vpnAddrs[0] == network.Addr() {
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? filteredNetworks = append(filteredNetworks, network)
correctHostResponded = true vpnAddrs = append(vpnAddrs, vpnAddr)
} }
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
} }
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr). WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
@@ -598,7 +618,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip // Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
@@ -625,7 +644,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -633,17 +652,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration). WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)) WithField("sentCachedPackets", len(hh.packetStore)).
if anyVpnAddrsInCommon { Info("Handshake message received")
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)

View File

@@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to // Don't relay to myself
if relay == vpnIp { if relay == vpnIp {
continue continue
} }
// Don't relay to myself // Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) { if hm.f.myVpnAddrsTable.Contains(relay) {
continue continue
} }

View File

@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r rs.relayForByIdx[idx] = r
} }
type NetworkType uint8
const (
NetworkTypeUnknown NetworkType = iota
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
NetworkTypeVPN
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
)
type HostInfo struct { type HostInfo struct {
remote netip.AddrPort remote netip.AddrPort
remotes *RemoteList remotes *RemoteList
@@ -237,8 +225,8 @@ type HostInfo struct {
// vpn networks but were removed because they are not usable // vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. // networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[NetworkType] networks *bart.Lite
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@@ -742,26 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false return false
} }
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up. func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) { if len(networks) == 1 && len(unsafeNetworks) == 0 {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) { return
return // Simple case, no BART needed
}
} }
i.networks = new(bart.Table[NetworkType]) i.networks = new(bart.Lite)
for _, network := range c.Networks() { for _, network := range networks {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
if myVpnNetworksTable.Contains(network.Addr()) { i.networks.Insert(nprefix)
i.networks.Insert(nprefix, NetworkTypeVPN)
} else {
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
}
} }
for _, network := range c.UnsafeNetworks() { for _, network := range unsafeNetworks {
i.networks.Insert(network, NetworkTypeUnsafe) i.networks.Insert(network)
} }
} }

154
inside.go
View File

@@ -11,6 +11,149 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
// consumeInsidePackets processes multiple packets in a batch for improved performance
// packets: slice of packet buffers to process
// sizes: slice of packet sizes
// count: number of packets to process
// outs: slice of output buffers (one per packet) with virtio headroom
// q: queue index
// localCache: firewall conntrack cache
// batchPackets: pre-allocated slice for accumulating encrypted packets
// batchAddrs: pre-allocated slice for accumulating destination addresses
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
// Reusable per-packet state
fwPacket := &firewall.Packet{}
// Reset batch accumulation slices (reuse capacity)
*batchPackets = (*batchPackets)[:0]
*batchAddrs = (*batchAddrs)[:0]
// Process each packet in the batch
for i := 0; i < count; i++ {
packet := packets[i][:sizes[i]]
out := outs[i]
// Inline the consumeInsidePacket logic for better performance
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
}
continue
}
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
continue
}
}
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
// Immediately forward packets from self to self.
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
}
}
continue
}
// Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
continue
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
}
continue
}
if !ready {
continue
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
continue
}
// Encrypt and prepare packet for batch sending
ci := hostinfo.ConnectionState
if ci.eKey == nil {
continue
}
// Check if this needs relay - if so, send immediately and skip batching
useRelay := !hostinfo.remote.IsValid()
if useRelay {
// Handle relay sends individually (less common path)
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
continue
}
// Encrypt the packet for batch sending
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query lighthouse if needed
if hostinfo.lastRebindCount != f.rebindCount {
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("counter", c).
Error("Failed to encrypt outgoing packet")
continue
}
// Add to batch
*batchPackets = append(*batchPackets, out)
*batchAddrs = append(*batchAddrs, hostinfo.remote)
}
// Send all accumulated packets in one batch
if len(*batchPackets) > 0 {
batchSize := len(*batchPackets)
f.batchMetrics.udpWriteSize.Update(int64(batchSize))
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
}
}
}
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
@@ -120,10 +263,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
// it does not check if it is within our vpn networks!
func (f *Interface) Handshake(vpnAddr netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.handshakeManager.GetOrHandshake(vpnAddr, nil) f.getOrHandshakeNoRouting(vpnAddr, nil)
} }
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -139,6 +281,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -231,10 +374,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
@@ -22,6 +21,7 @@ import (
) )
const mtu = 9001 const mtu = 9001
const virtioNetHdrLen = overlay.VirtioNetHdrLen
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
@@ -50,6 +50,13 @@ type InterfaceConfig struct {
l *logrus.Logger l *logrus.Logger
} }
type batchMetrics struct {
udpReadSize metrics.Histogram
tunReadSize metrics.Histogram
udpWriteSize metrics.Histogram
tunWriteSize metrics.Histogram
}
type Interface struct { type Interface struct {
hostMap *HostMap hostMap *HostMap
outside udp.Conn outside udp.Conn
@@ -86,11 +93,12 @@ type Interface struct {
conntrackCacheTimeout time.Duration conntrackCacheTimeout time.Duration
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []overlay.BatchReadWriter
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
batchMetrics *batchMetrics
l *logrus.Logger l *logrus.Logger
} }
@@ -177,7 +185,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines, routines: c.routines,
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]overlay.BatchReadWriter, c.routines),
myVpnNetworks: cs.myVpnNetworks, myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable, myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs, myVpnAddrs: cs.myVpnAddrs,
@@ -193,6 +201,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil), sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
}, },
batchMetrics: &batchMetrics{
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)),
udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)),
tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)),
},
l: c.l, l: c.l,
} }
@@ -222,17 +236,10 @@ func (f *Interface) activate() {
WithField("boringcrypto", boringEnabled()). WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active") Info("Nebula interface is active")
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
f.routines = 1
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
}
}
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
var reader io.ReadWriteCloser = f.inside var reader overlay.BatchReadWriter = f.inside
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
@@ -273,39 +280,69 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
// Pre-allocate output buffers for batch processing
batchSize := li.BatchSize()
outs := make([][]byte, batchSize)
for idx := range outs {
// Allocate full buffer with virtio header space
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
}
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
}) })
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
runtime.LockOSThread() runtime.LockOSThread()
packet := make([]byte, mtu) batchSize := reader.BatchSize()
out := make([]byte, mtu)
fwPacket := &firewall.Packet{} // Allocate buffers for batch reading
nb := make([]byte, 12, 12) bufs := make([][]byte, batchSize)
for idx := range bufs {
bufs[idx] = make([]byte, mtu)
}
sizes := make([]int, batchSize)
// Allocate output buffers for batch processing (one per packet)
// Each has virtio header headroom to avoid copies on write
outs := make([][]byte, batchSize)
for idx := range outs {
outBuf := make([]byte, virtioNetHdrLen+mtu)
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
}
// Pre-allocate batch accumulation buffers for sending
batchPackets := make([][]byte, 0, batchSize)
batchAddrs := make([]netip.AddrPort, 0, batchSize)
// Pre-allocate nonce buffer (reused for all encryptions)
nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
n, err := reader.Read(packet) n, err := reader.BatchRead(bufs, sizes)
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() { if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return return
} }
f.l.WithError(err).Error("Error while reading outbound packet") f.l.WithError(err).Error("Error while batch reading outbound packets")
// This only seems to happen when something fatal happens to the fd, so exit. // This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2) os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) f.batchMetrics.tunReadSize.Update(int64(n))
// Process all packets in the batch at once
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
} }
} }

View File

@@ -360,8 +360,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
} }
out[i] = addr out[i] = addr
} }
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
} }
if !lh.myVpnNetworksTable.Contains(vpnAddr) { if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
} }
vals, ok := v.([]any) vals, ok := v.([]any)

View File

@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {

25
main.go
View File

@@ -5,8 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime/debug"
"strings"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -15,7 +13,7 @@ import (
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any
@@ -29,10 +27,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
}() }()
if buildVersion == "" {
buildVersion = moduleVersion()
}
l := logger l := logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
@@ -171,7 +165,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
@@ -302,18 +296,3 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
connManager.Start, connManager.Start,
}, nil }, nil
} }
func moduleVersion() string {
info, ok := debug.ReadBuildInfo()
if !ok {
return ""
}
for _, dep := range info.Deps {
if dep.Path == "github.com/slackhq/nebula" {
return strings.TrimPrefix(dep.Version, "v")
}
}
return ""
}

View File

@@ -95,8 +95,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch relay.Type { switch relay.Type {
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -138,7 +137,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
} }
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) lhf.HandleRequest(ip, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
@@ -160,7 +159,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding
f.handleHostRoaming(hostinfo, ip) f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
} }
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
@@ -203,7 +202,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
} }
f.relayManager.HandleControlMsg(hostinfo, d, f) f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -474,9 +473,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false return false
} }
err = newPacket(out, true, fwPacket) packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out). hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
Warnf("Error while validating inbound packet") Warnf("Error while validating inbound packet")
return false return false
} }
@@ -491,7 +492,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
if dropReason != nil { if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in // This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket). hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason). WithField("reason", dropReason).
@@ -548,3 +549,108 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
// We also delete it from pending hostmap to allow for fast reconnect. // We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
} }
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
// and writes all successfully decrypted packets to TUN in a single operation
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
// Pre-allocate slice for accumulating successful decryptions
tunPackets := make([][]byte, 0, count)
for i := 0; i < count; i++ {
payload := payloads[i]
addr := addrs[i]
out := outs[i]
// Parse header
err := h.Parse(payload)
if err != nil {
if len(payload) > 1 {
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
continue
}
if addr.IsValid() {
if f.myVpnNetworksTable.Contains(addr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
}
continue
}
}
var hostinfo *HostInfo
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, addr, h) {
continue
}
switch h.Subtype {
case header.MessageNone:
// Decrypt packet
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
continue
}
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
continue
}
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
continue
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
}
continue
}
f.connectionManager.In(hostinfo)
// Add to batch for TUN write
tunPackets = append(tunPackets, out)
case header.MessageRelay:
// Skip relay packets in batch mode for now (less common path)
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
default:
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
}
default:
// Handle non-Message types using single-packet path
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
}
}
if len(tunPackets) > 0 {
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
}
f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets)))
}
}

View File

@@ -7,12 +7,25 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
type Device interface { // BatchReadWriter extends io.ReadWriteCloser with batch I/O operations
type BatchReadWriter interface {
io.ReadWriteCloser io.ReadWriteCloser
// BatchRead reads multiple packets at once
BatchRead(bufs [][]byte, sizes []int) (int, error)
// WriteBatch writes multiple packets at once
WriteBatch(bufs [][]byte, offset int) (int, error)
// BatchSize returns the optimal batch size for this device
BatchSize() int
}
type Device interface {
BatchReadWriter
Activate() error Activate() error
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool NewMultiQueueReader() (BatchReadWriter, error)
NewMultiQueueReader() (io.ReadWriteCloser, error)
} }

View File

@@ -11,6 +11,7 @@ import (
) )
const DefaultMTU = 1300 const DefaultMTU = 1300
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
// TODO: We may be able to remove routines // TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)

View File

@@ -95,10 +95,29 @@ func (t *tun) Name() string {
return "android" return "android"
} }
func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android") return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
} }
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -549,10 +549,32 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
// BatchRead reads a single packet (batch size 1 for non-Linux platforms)
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for non-Linux platforms)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for non-Linux platforms (no batching)
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -105,12 +105,34 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (t *disabledTun) SupportsMultiqueue() bool { func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) {
return true return t, nil
} }
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { // BatchRead reads a single packet (batch size 1 for disabled tun)
return t, nil func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for disabled tun)
func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for disabled tun (no batching)
func (t *disabledTun) BatchSize() int {
return 1
} }
func (t *disabledTun) Close() error { func (t *disabledTun) Close() error {

View File

@@ -450,12 +450,34 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return false return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { // BatchRead reads a single packet (batch size 1 for FreeBSD)
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for FreeBSD)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for FreeBSD (no batching)
func (t *tun) BatchSize() int {
return 1
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {

View File

@@ -151,10 +151,29 @@ func (t *tun) Name() string {
return "iOS" return "iOS"
} }
func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
} }
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -9,7 +9,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" "unsafe"
@@ -21,10 +20,12 @@ import (
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
) )
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
wgDevice wgtun.Device
fd int fd int
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
@@ -65,59 +66,154 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") // This allows multiqueue readers to use the same wireguard Device batching as the main device
type wgDeviceWrapper struct {
dev wgtun.Device
buf []byte // Reusable buffer for single packet reads
}
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
// Use wireguard Device's batch API for single packet
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := w.dev.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
// Pass all buffers to WireGuard's batch write
return w.dev.Write(bufs, offset)
}
func (w *wgDeviceWrapper) Close() error {
return w.dev.Close()
}
// BatchRead implements batching for multiqueue readers
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
// The zero here is offset.
return w.dev.Read(bufs, sizes, 0)
}
// BatchSize returns the optimal batch size
func (w *wgDeviceWrapper) BatchSize() int {
return w.dev.BatchSize()
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
if err != nil {
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
}
file := wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
_ = wgDev.Close()
return nil, err return nil, err
} }
t.Device = "tun0" t.wgDevice = wgDev
t.Device = name
return t, nil return t, nil
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) // Check if /dev/net/tun exists, create if needed (for docker containers)
if err != nil { if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker) if err := os.MkdirAll("/dev/net", 0755); err != nil {
if os.IsNotExist(err) {
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
} }
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
if err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
} }
}
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) devName := c.GetString("tun.dev", "")
mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create TUN device manually to support multiqueue
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err return nil, err
} }
}
var req ifReq var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
if multiqueue { if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE req.Flags |= unix.IFF_MULTI_QUEUE
} }
copy(req.Name[:], c.GetString("tun.dev", "")) copy(req.Name[:], devName)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err return nil, err
} }
name := strings.Trim(string(req.Name[:]), "\x00")
// Set nonblocking
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Enable TCP and UDP offload (TSO/GRO) for performance
// This allows the kernel to handle segmentation/coalescing
const (
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
offloads := tunTCPOffloads | tunUDPOffloads
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
// Log warning but don't fail - offload is optional
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
}
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
// Create wireguard device from file descriptor
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
}
name, err := wgDev.Name()
if err != nil {
_ = wgDev.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
}
// file is now owned by wgDev, get a new reference
file = wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
_ = wgDev.Close()
return nil, err return nil, err
} }
t.wgDevice = wgDev
t.Device = name t.Device = name
return t, nil return t, nil
@@ -216,26 +312,44 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var req ifReq var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) // MUST match the flags used in newTun - includes IFF_VNET_HDR
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device) copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err return nil, err
} }
// Set nonblocking mode - CRITICAL for proper netpoller integration
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Get MTU from main device
mtu := t.MaxMTU
if mtu == 0 {
mtu = DefaultMTU
}
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil // Create wireguard Device from the file descriptor (just like the main device)
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
}
// Return a wrapper that uses the wireguard Device for all I/O
return &wgDeviceWrapper{dev: wgDev}, nil
} }
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -243,7 +357,68 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r return r
} }
func (t *tun) Read(b []byte) (int, error) {
if t.wgDevice != nil {
// Use wireguard device which handles virtio headers internally
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := t.wgDevice.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
// Fallback: direct read from file (shouldn't happen in normal operation)
return t.ReadWriteCloser.Read(b)
}
// BatchRead reads multiple packets at once for improved performance
// bufs: slice of buffers to read into
// sizes: slice that will be filled with packet sizes
// Returns number of packets read
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Read(bufs, sizes, 0)
}
// Fallback: single packet read
n, err := t.ReadWriteCloser.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// BatchSize returns the optimal number of packets to read/write in a batch
func (t *tun) BatchSize() int {
if t.wgDevice != nil {
return t.wgDevice.BatchSize()
}
return 1
}
func (t *tun) Write(b []byte) (int, error) { func (t *tun) Write(b []byte) (int, error) {
if t.wgDevice != nil {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
// Fallback: direct write (shouldn't happen in normal operation)
var nn int var nn int
maximum := len(b) maximum := len(b)
@@ -266,6 +441,22 @@ func (t *tun) Write(b []byte) (int, error) {
} }
} }
// WriteBatch writes multiple packets to the TUN device in a single syscall
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Write(bufs, offset)
}
// Fallback: write individually (shouldn't happen in normal operation)
for i, buf := range bufs {
_, err := t.Write(buf)
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) deviceBytes() (o [16]byte) { func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device { for i, c := range t.Device {
o[i] = byte(c) o[i] = byte(c)
@@ -586,42 +777,48 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
} }
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device) link, err := netlink.LinkByName(t.Device)
if err != nil { if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways return gateways
} }
// If this route is relevant to our interface and there is a gateway then add it // If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index { if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := getGatewayAddr(r.Gw, r.Via) gwAddr, ok := netip.AddrFromSlice(r.Gw)
if ok { if !ok {
if t.isGatewayInVpnNetworks(gwAddr) { t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else { } else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore // Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
}
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
} }
} }
for _, p := range r.MultiPath { for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it // If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index { if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := getGatewayAddr(p.Gw, p.Via) gwAddr, ok := netip.AddrFromSlice(p.Gw)
if ok { if !ok {
if t.isGatewayInVpnNetworks(gwAddr) { t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else { } else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore // Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
}
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") // p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
} }
} }
} }
@@ -630,27 +827,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
return gateways return gateways
} }
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
// Try to use the old RTA_GATEWAY first
gwAddr, ok := netip.AddrFromSlice(gw)
if !ok {
// Fallback to the new RTA_VIA
rVia, ok := via.(*netlink.Via)
if ok {
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
}
}
if gwAddr.IsValid() {
gwAddr = gwAddr.Unmap()
return gwAddr, true
}
return netip.Addr{}, false
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) { func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route) gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 { if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required. // No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways") t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
@@ -689,6 +869,10 @@ func (t *tun) Close() error {
close(t.routeChan) close(t.routeChan)
} }
if t.wgDevice != nil {
_ = t.wgDevice.Close()
}
if t.ReadWriteCloser != nil { if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close() _ = t.ReadWriteCloser.Close()
} }

View File

@@ -390,12 +390,31 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return false return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {

View File

@@ -310,12 +310,31 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return false return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {

View File

@@ -132,10 +132,29 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (t *TestTun) SupportsMultiqueue() bool { func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented") return nil, fmt.Errorf("TODO: multiqueue not implemented")
} }
func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *TestTun) BatchSize() int {
return 1
}

View File

@@ -6,7 +6,6 @@ package overlay
import ( import (
"crypto" "crypto"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
@@ -234,12 +233,34 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0) return t.tun.Write(b, 0)
} }
func (t *winTun) SupportsMultiqueue() bool { func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) {
return false return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
} }
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { // BatchRead reads a single packet (batch size 1 for Windows)
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for Windows)
func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for Windows (no batching)
func (t *winTun) BatchSize() int {
return 1
} }
func (t *winTun) Close() error { func (t *winTun) Close() error {

View File

@@ -46,12 +46,34 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)} return routing.Gateways{routing.NewGateway(ip, 1)}
} }
func (d *UserDevice) SupportsMultiqueue() bool { func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) {
return true return d, nil
} }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { // BatchRead reads a single packet (batch size 1 for UserDevice)
return d, nil func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := d.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for UserDevice)
func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := d.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for UserDevice (no batching)
func (d *UserDevice) BatchSize() int {
return 1
} }
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {

8
pki.go
View File

@@ -523,14 +523,10 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
} }
bl := c.GetStringSlice("pki.blocklist", []string{}) for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
if len(bl) > 0 { l.WithField("fingerprint", fp).Info("Blocklisting cert")
for _, fp := range bl {
caPool.BlocklistFingerprint(fp) caPool.BlocklistFingerprint(fp)
} }
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
}
return caPool, nil return caPool, nil
} }

View File

@@ -16,8 +16,8 @@ import (
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any

View File

@@ -6,6 +6,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
_ "net/http/pprof"
"runtime" "runtime"
"strconv" "strconv"
"time" "time"

View File

@@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil return 0, nil
} }
func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported") return nil, errors.New("unsupported")
} }

View File

@@ -13,13 +13,21 @@ type EncReader func(
payload []byte, payload []byte,
) )
type EncBatchReader func(
addrs []netip.AddrPort,
payloads [][]byte,
count int,
)
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) ListenOut(r EncReader)
ListenOutBatch(r EncBatchReader)
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
SupportsMultipleReaders() bool BatchSize() int
Close() error Close() error
} }
@@ -34,15 +42,21 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) ListenOut(_ EncReader) {
return return
} }
func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) ListenOutBatch(_ EncBatchReader) {
return false return
} }
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
return 0, nil
}
func (NoopConn) ReloadConfig(_ *config.C) { func (NoopConn) ReloadConfig(_ *config.C) {
return return
} }
func (NoopConn) BatchSize() int {
return 1
}
func (NoopConn) Close() error { func (NoopConn) Close() error {
return nil return nil
} }

View File

@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
return ErrInvalidIPv6RemoteForSocket return ErrInvalidIPv6RemoteForSocket
} }
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As4() rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa) sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4 addrLen = syscall.SizeofSockaddrInet4
@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
} }
} }
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) { func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()
@@ -184,8 +195,32 @@ func (u *StdConn) ListenOut(r EncReader) {
} }
} }
func (u *StdConn) SupportsMultipleReaders() bool { // ListenOutBatch - fallback to single-packet reads for Darwin
return false func (u *StdConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
func (u *StdConn) BatchSize() int {
return 1
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {

View File

@@ -86,6 +86,41 @@ func (u *GenericConn) ListenOut(r EncReader) {
} }
} }
func (u *GenericConn) SupportsMultipleReaders() bool { // ListenOutBatch - fallback to single-packet reads for generic platforms
return false func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
// WriteMulti sends multiple packets - fallback implementation
func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *GenericConn) BatchSize() int {
return 1
}
func (u *GenericConn) Rebind() error {
return nil
} }

View File

@@ -22,6 +22,11 @@ type StdConn struct {
isV4 bool isV4 bool
l *logrus.Logger l *logrus.Logger
batch int batch int
// Pre-allocated buffers for batch writes (sized for IPv6, works for both)
writeMsgs []rawMessage
writeIovecs []iovec
writeNames [][]byte
} }
func maybeIPV4(ip net.IP) (net.IP, bool) { func maybeIPV4(ip net.IP) (net.IP, bool) {
@@ -69,11 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err) return nil, fmt.Errorf("unable to bind to socket: %s", err)
} }
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}
// Pre-allocate write message structures for batching (sized for IPv6, works for both)
c.writeMsgs = make([]rawMessage, batch)
c.writeIovecs = make([]iovec, batch)
c.writeNames = make([][]byte, batch)
for i := range c.writeMsgs {
// Allocate for IPv6 size (larger than IPv4, works for both)
c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
// Point to the iovec in the slice
c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i]
c.writeMsgs[i].Hdr.Iovlen = 1
c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0]
// Namelen will be set appropriately in writeMulti4/writeMulti6
} }
func (u *StdConn) SupportsMultipleReaders() bool { return c, err
return true
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
@@ -131,6 +151,8 @@ func (u *StdConn) ListenOut(r EncReader) {
read = u.ReadSingle read = u.ReadSingle
} }
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
for { for {
n, err := read(msgs) n, err := read(msgs)
if err != nil { if err != nil {
@@ -138,6 +160,8 @@ func (u *StdConn) ListenOut(r EncReader) {
return return
} }
udpBatchHist.Update(int64(n))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 { if u.isV4 {
@@ -150,6 +174,46 @@ func (u *StdConn) ListenOut(r EncReader) {
} }
} }
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti
if u.batch == 1 {
read = u.ReadSingle
}
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
// Pre-allocate slices for batch callback
addrs := make([]netip.AddrPort, u.batch)
payloads := make([][]byte, u.batch)
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpBatchHist.Update(int64(n))
// Prepare batch data
for i := 0; i < n; i++ {
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payloads[i] = buffers[i][:msgs[i].Len]
}
// Call batch callback with all packets
r(addrs, payloads, n)
}
}
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for { for {
n, _, err := unix.Syscall6( n, _, err := unix.Syscall6(
@@ -198,6 +262,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip) return u.writeTo6(b, ip)
} }
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
if len(packets) != len(addrs) {
return 0, fmt.Errorf("packets and addrs length mismatch")
}
if len(packets) == 0 {
return 0, nil
}
if u.isV4 {
return u.writeMulti4(packets, addrs)
}
return u.writeMulti6(packets, addrs)
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6 rsa.Family = unix.AF_INET6
@@ -252,6 +329,123 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
} }
} }
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
if !addrs[pktIdx].Addr().Is4() {
return sent + i, ErrInvalidIPv6RemoteForSocket
}
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET
rsa.Addr = addrs[pktIdx].Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv4
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET6
rsa.Addr = addrs[pktIdx].Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv6
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0) b := c.GetInt("listen.read_buffer", 0)
if b > 0 { if b > 0 {
@@ -309,6 +503,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
return nil return nil
} }
func (u *StdConn) BatchSize() int {
return u.batch
}
func (u *StdConn) Close() error { func (u *StdConn) Close() error {
return syscall.Close(u.sysFd) return syscall.Close(u.sysFd)
} }

View File

@@ -12,7 +12,7 @@ import (
type iovec struct { type iovec struct {
Base *byte Base *byte
Len uint32 Len uint
} }
type msghdr struct { type msghdr struct {
@@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, {Base: &buffers[i][0], Len: uint(len(buffers[i]))},
} }
msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iov = &vs[0]

View File

@@ -12,7 +12,7 @@ import (
type iovec struct { type iovec struct {
Base *byte Base *byte
Len uint64 Len uint
} }
type msghdr struct { type msghdr struct {
@@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, {Base: &buffers[i][0], Len: uint(len(buffers[i]))},
} }
msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iov = &vs[0]

View File

@@ -315,10 +315,6 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
} }
func (u *RIOConn) SupportsMultipleReaders() bool {
return false
}
func (u *RIOConn) Rebind() error { func (u *RIOConn) Rebind() error {
return nil return nil
} }

View File

@@ -116,6 +116,31 @@ func (u *TesterConn) ListenOut(r EncReader) {
} }
} }
func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
p, ok := <-u.RxPackets
if !ok {
return
}
addrs[0] = p.From
payloads[0] = p.Data
r(addrs, payloads, 1)
}
}
func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *TesterConn) ReloadConfig(*config.C) {} func (u *TesterConn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []Conn) func() { func NewUDPStatsEmitter(_ []Conn) func() {
@@ -127,8 +152,8 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil return u.Addr, nil
} }
func (u *TesterConn) SupportsMultipleReaders() bool { func (u *TesterConn) BatchSize() int {
return false return 1
} }
func (u *TesterConn) Rebind() error { func (u *TesterConn) Rebind() error {