Compare commits

..

9 Commits

Author SHA1 Message Date
Jay Wren
8d4dd26484 try to fix windows? 2025-10-23 16:50:10 -04:00
Jay Wren
0a94f9f990 fix tun_linux_test.go & a nit 2025-10-23 16:16:40 -04:00
Jay Wren
433c531ae4 put linux tx_queue back 2025-10-23 16:06:13 -04:00
Jay Wren
4c0aad1b1f put back the previous darwin interface naming 2025-10-23 16:04:52 -04:00
Jay Wren
c8b0281736 gofmt 2025-10-23 14:50:50 -04:00
Jay Wren
8281b1699f fix BatchRead interface & make batch size configurable 2025-10-23 14:37:26 -04:00
Jay Wren
0827a6f1c5 claude suggests this bound checking optimization
it is about 4.4% better on the bench-cpu-long cpu usage test
2025-10-23 13:27:37 -04:00
Jay Wren
273119638d claude did not want netbsd 2025-10-22 17:32:03 -04:00
Jay Wren
484de41b58 claude first stab and first cleanup - ugh 2025-10-22 16:49:26 -04:00
64 changed files with 1879 additions and 3760 deletions

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -10,7 +10,7 @@ jobs:
name: Build Linux/BSD All name: Build Linux/BSD All
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -24,7 +24,7 @@ jobs:
mv build/*.tar.gz release mv build/*.tar.gz release
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v4
with: with:
name: linux-latest name: linux-latest
path: release path: release
@@ -33,7 +33,7 @@ jobs:
name: Build Windows name: Build Windows
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -55,7 +55,7 @@ jobs:
mv dist\windows\wintun build\dist\windows\ mv dist\windows\wintun build\dist\windows\
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v4
with: with:
name: windows-latest name: windows-latest
path: build path: build
@@ -66,7 +66,7 @@ jobs:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -104,7 +104,7 @@ jobs:
fi fi
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v4
with: with:
name: darwin-latest name: darwin-latest
path: ./release/* path: ./release/*
@@ -124,11 +124,11 @@ jobs:
# be overwritten # be overwritten
- name: Checkout code - name: Checkout code
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/checkout@v5 uses: actions/checkout@v4
- name: Download artifacts - name: Download artifacts
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/download-artifact@v6 uses: actions/download-artifact@v4
with: with:
name: linux-latest name: linux-latest
path: artifacts path: artifacts
@@ -160,10 +160,10 @@ jobs:
needs: [build-linux, build-darwin, build-windows] needs: [build-linux, build-darwin, build-windows]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- name: Download artifacts - name: Download artifacts
uses: actions/download-artifact@v6 uses: actions/download-artifact@v4
with: with:
path: artifacts path: artifacts

View File

@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -32,7 +32,7 @@ jobs:
run: make vet run: make vet
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v9 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.5 version: v2.5
@@ -45,7 +45,7 @@ jobs:
- name: Build test mobile - name: Build test mobile
run: make build-test-mobile run: make build-test-mobile
- uses: actions/upload-artifact@v5 - uses: actions/upload-artifact@v4
with: with:
name: e2e packet flow linux-latest name: e2e packet flow linux-latest
path: e2e/mermaid/linux-latest path: e2e/mermaid/linux-latest
@@ -56,7 +56,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -77,7 +77,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -98,7 +98,7 @@ jobs:
os: [windows-latest, macos-latest] os: [windows-latest, macos-latest]
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -115,7 +115,7 @@ jobs:
run: make vet run: make vet
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v9 uses: golangci/golangci-lint-action@v8
with: with:
version: v2.5 version: v2.5
@@ -125,7 +125,7 @@ jobs:
- name: End 2 end - name: End 2 end
run: make e2evv run: make e2evv
- uses: actions/upload-artifact@v5 - uses: actions/upload-artifact@v4
with: with:
name: e2e packet flow ${{ matrix.os }} name: e2e packet flow ${{ matrix.os }}
path: e2e/mermaid/${{ matrix.os }} path: e2e/mermaid/${{ matrix.os }}

107
bits.go
View File

@@ -9,13 +9,14 @@ type Bits struct {
length uint64 length uint64
current uint64 current uint64
bits []bool bits []bool
firstSeen bool
lostCounter metrics.Counter lostCounter metrics.Counter
dupeCounter metrics.Counter dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter outOfWindowCounter metrics.Counter
} }
func NewBits(bits uint64) *Bits { func NewBits(bits uint64) *Bits {
b := &Bits{ return &Bits{
length: bits, length: bits,
bits: make([]bool, bits, bits), bits: make([]bool, bits, bits),
current: 0, current: 0,
@@ -23,37 +24,34 @@ func NewBits(bits uint64) *Bits {
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil), outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
} }
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
return b
} }
func (b *Bits) Check(l *logrus.Logger, i uint64) bool { func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current { if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true return true
} }
// If i is within the window, check if it's been set already. // If i is within the window, check if it's been set already. The first window will fail this check
if i > b.current-b.length || i < b.length && b.current < b.length { if i > b.current-b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
return !b.bits[i%b.length] return !b.bits[i%b.length]
} }
// Not within the window // Not within the window
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i) l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
}
return false return false
} }
func (b *Bits) Update(l *logrus.Logger, i uint64) bool { func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current. // If i is the next number, return true and update current.
if i == b.current+1 { if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter // Report missed packets, we can only understand what was missed after the first window has been gone through
// The very first window can only be tracked as lost once we are on the 2nd window or greater if i > b.length && b.bits[i%b.length] == false {
if b.bits[i%b.length] == false && i > b.length {
b.lostCounter.Inc(1) b.lostCounter.Inc(1)
} }
b.bits[i%b.length] = true b.bits[i%b.length] = true
@@ -61,32 +59,61 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return true return true
} }
// If i is a jump, adjust the window, record lost, update current, and return true // If i packet is greater than current but less than the maximum length of our bitmap,
if i > b.current { // flip everything in between to false and move ahead.
lost := int64(0) if i > b.current && i < b.current+b.length {
// Zero out the bits between the current and the new counter value, limited by the window size, // In between current and i need to be zero'd to allow those packets to come in later
// since the window is shifting for n := b.current + 1; n < i; n++ {
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
}
b.bits[n%b.length] = false b.bits[n%b.length] = false
} }
// Only record any skipped packets as a result of the window moving further than the window length b.bits[i%b.length] = true
// Any loss within the new window will be accounted for in future calls b.current = i
lost += max(0, int64(i-b.current-b.length)) //l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
b.lostCounter.Inc(lost) b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true b.bits[i%b.length] = true
b.current = i b.current = i
return true return true
} }
// If i is within the current window but below the current counter, // Allow for the 0 packet to come in within the first window
// Check to see if it's a duplicate if i == 0 && b.firstSeen == false && b.current < b.length {
if i > b.current-b.length || i < b.length && b.current < b.length { b.firstSeen = true
if b.current == i || b.bits[i%b.length] == true { b.bits[i%b.length] = true
return true
}
// If i is within the window of current minus length (the total pat window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window") Debug("Receive window")
@@ -95,8 +122,18 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false return false
} }
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true b.bits[i%b.length] = true
return true return true
} }
// In all other cases, fail and don't change current. // In all other cases, fail and don't change current.
@@ -110,3 +147,11 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
} }
return false return false
} }
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@@ -15,41 +15,48 @@ func TestBits(t *testing.T) {
assert.Len(t, b.bits, 10) assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work. // This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1)) assert.True(t, b.Check(l, 1))
assert.True(t, b.Update(l, 1)) u := b.Update(l, 1)
assert.True(t, u)
assert.EqualValues(t, 1, b.current) assert.EqualValues(t, 1, b.current)
g := []bool{true, true, false, false, false, false, false, false, false, false} g := []bool{false, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two // Receive two
assert.True(t, b.Check(l, 2)) assert.True(t, b.Check(l, 2))
assert.True(t, b.Update(l, 2)) u = b.Update(l, 2)
assert.True(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
g = []bool{true, true, true, false, false, false, false, false, false, false} g = []bool{false, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two again - it will fail // Receive two again - it will fail
assert.False(t, b.Check(l, 2)) assert.False(t, b.Check(l, 2))
assert.False(t, b.Update(l, 2)) u = b.Update(l, 2)
assert.False(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element // Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15)) assert.True(t, b.Check(l, 15))
assert.True(t, b.Update(l, 15)) u = b.Update(l, 15)
assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false} g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window // Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14)) assert.True(t, b.Check(l, 14))
assert.True(t, b.Update(l, 14)) u = b.Update(l, 14)
assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window // Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(l, 5)) assert.False(t, b.Check(l, 5))
assert.False(t, b.Update(l, 5)) u = b.Update(l, 5)
assert.False(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
@@ -62,29 +69,10 @@ func TestBits(t *testing.T) {
// Walk through a few windows in order // Walk through a few windows in order
b = NewBits(10) b = NewBits(10)
for i := uint64(1); i <= 100; i++ { for i := uint64(0); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i)
} }
assert.False(t, b.Check(l, 1), "Out of window check")
}
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
} }
func TestBitsDupeCounter(t *testing.T) { func TestBitsDupeCounter(t *testing.T) {
@@ -136,7 +124,8 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
assert.False(t, b.Update(l, 0)) assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost //tODO: make sure lostcounter doesn't increase in orderly increment
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
} }
@@ -148,6 +137,8 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20)) assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21)) assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22)) assert.True(t, b.Update(l, 22))
@@ -158,7 +149,7 @@ func TestBitsLostCounter(t *testing.T) {
assert.True(t, b.Update(l, 27)) assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28)) assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29)) assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -167,6 +158,8 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9)) assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets // 10 will set 0 index, 0 was already set, no lost packets
@@ -221,62 +214,6 @@ func TestBitsLostCounter(t *testing.T) {
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
} }
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) { func BenchmarkBits(b *testing.B) {
z := NewBits(10) z := NewBits(10)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {

View File

@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
} }
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
nc := &cert.TBSCertificate{
Version: v,
Curve: c.Curve(),
Name: c.Name(),
Networks: c.Networks(),
UnsafeNetworks: c.UnsafeNetworks(),
Groups: c.Groups(),
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
PublicKey: c.PublicKey(),
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pem
}
func X25519Keypair() ([]byte, []byte) { func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32) privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil { if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

View File

@@ -173,8 +173,6 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var passphrase []byte var passphrase []byte
if !isP11 && *cf.encryption { if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
@@ -194,7 +192,6 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
} }
} }
}
var curve cert.Curve var curve cert.Curve
var pub, rawPriv []byte var pub, rawPriv []byte

View File

@@ -171,17 +171,6 @@ func Test_ca(t *testing.T) {
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, ca(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// read encrypted key file and verify default params // read encrypted key file and verify default params
rb, _ = os.ReadFile(keyF.Name()) rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb) k, _ := pem.Decode(rb)

View File

@@ -5,28 +5,10 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime/debug"
"strings"
) )
// A version string that can be set with
//
// -ldflags "-X main.Build=SOMEVERSION"
//
// at compile-time.
var Build string var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
type helpError struct { type helpError struct {
s string s string
} }

View File

@@ -43,7 +43,7 @@ type signFlags struct {
func newSignFlags() *signFlags { func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {} sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA") sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -116,10 +116,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted // naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if errors.Is(err, cert.ErrPrivateKeyEncrypted) { if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
var passphrase []byte
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
// ask for a passphrase until we get one // ask for a passphrase until we get one
var passphrase []byte
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
@@ -137,7 +135,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if len(passphrase) == 0 { if len(passphrase) == 0 {
return fmt.Errorf("cannot open encrypted ca-key without passphrase") return fmt.Errorf("cannot open encrypted ca-key without passphrase")
} }
}
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
if err != nil { if err != nil {
return fmt.Errorf("error while parsing encrypted ca-key: %s", err) return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
@@ -167,10 +165,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("ca certificate is expired") return fmt.Errorf("ca certificate is expired")
} }
if version == 0 {
version = caCert.Version()
}
// if no duration is given, expire one second before the root expires // if no duration is given, expire one second before the root expires
if *sf.duration <= 0 { if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -283,19 +277,21 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
notBefore := time.Now() notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration) notAfter := notBefore.Add(*sf.duration)
switch version { if version == 0 || version == cert.Version1 {
case cert.Version1: // Make sure we at least have an ip
// Make sure we have only one ipv4 address
if len(v4Networks) != 1 { if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
} }
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 { if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses") return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
} }
if len(v6UnsafeNetworks) > 0 { if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
} }
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
@@ -325,8 +321,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
crts = append(crts, nc) crts = append(crts, nc)
}
case cert.Version2: if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
Version: cert.Version2, Version: cert.Version2,
Name: *sf.name, Name: *sf.name,
@@ -354,9 +351,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
crts = append(crts, nc) crts = append(crts, nc)
default:
// this should be unreachable
return fmt.Errorf("invalid version: %d", version)
} }
if !isP11 && *sf.inPubPath == "" { if !isP11 && *sf.inPubPath == "" {

View File

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
" -unsafe-networks string\n"+ " -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+ " -version uint\n"+
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n", " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(), ob.String(),
) )
} }
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -379,15 +379,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the wrong password // test with the wrong password
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
@@ -398,17 +389,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// test with the wrong password in environment
ob.Reset()
eb.Reset()
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the user not entering a password // test with the user not entering a password
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()

View File

@@ -4,8 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@@ -20,17 +18,6 @@ import (
// at compile-time. // at compile-time.
var Build string var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() { func main() {
serviceFlag := flag.String("service", "", "Control the system service.") serviceFlag := flag.String("service", "", "Control the system service.")
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")

View File

@@ -4,8 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@@ -20,17 +18,6 @@ import (
// at compile-time. // at compile-time.
var Build string var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() { func main() {
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config") configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")

View File

@@ -17,7 +17,7 @@ import (
"dario.cat/mergo" "dario.cat/mergo"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
type C struct { type C struct {

View File

@@ -10,7 +10,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {

View File

@@ -354,6 +354,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if mainHostInfo { if mainHostInfo {
decision = tryRehandshake decision = tryRehandshake
} else { } else {
if cm.shouldSwapPrimary(hostinfo) { if cm.shouldSwapPrimary(hostinfo) {
decision = swapPrimary decision = swapPrimary
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
} }
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
if crt == nil {
//my cert was reloaded away. We should definitely swap from this tunnel
return true
}
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down. // settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
cm.hostMap.Unlock() cm.hostMap.Unlock()
} }
// isInvalidCertificate decides if we should destroy a tunnel. // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid. // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true. // check and return true.
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert() remoteCert := hostinfo.GetCert()
if remoteCert == nil { if remoteCert == nil {
return false //don't tear down tunnels for handshakes in progress return false
} }
caPool := cm.intf.pki.GetCAPool() caPool := cm.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert) err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil { if err == nil {
return false //cert is still valid! yay! return false
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed }
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
// Block listed certificates should always be disconnected // Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err). return false
WithField("fingerprint", remoteCert.Fingerprint). }
Info("Remote certificate is blocked, tearing down the tunnel")
return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err). hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint). WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel") Info("Remote certificate is no longer valid, tearing down the tunnel")
return true return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
return false
}
} }
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState() cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert curCrt := hostinfo.ConnectionState.myCert
curCrtVersion := curCrt.Version() myCrt := cs.getCertificate(curCrt.Version())
myCrt := cs.getCertificate(curCrtVersion) if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
if myCrt == nil { // The current tunnel is using the latest certificate and version, no need to rehandshake.
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return return
} }
peerCrt := hostinfo.ConnectionState.peerCert
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
return
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current"). WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote") Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
} }

View File

@@ -50,6 +50,11 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
} }
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{ hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs, CipherSuite: ncs,
Random: rand.Reader, Random: rand.Reader,
@@ -69,7 +74,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
ci := &ConnectionState{ ci := &ConnectionState{
H: hs, H: hs,
initiator: initiator, initiator: initiator,
window: NewBits(ReplayWindow), window: b,
myCert: crt, myCert: crt,
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.

View File

@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap return c.f.hostMap
} }
func (c *Control) GetF() *Interface {
return c.f
}
func (c *Control) GetCertState() *CertState { func (c *Control) GetCertState() *CertState {
return c.f.pki.getCertState() return c.f.pki.getCertState()
} }

View File

@@ -20,17 +20,16 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@@ -39,9 +38,6 @@ func BenchmarkHotPath(b *testing.B) {
r := router.NewR(b, myControl, theirControl) r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs() r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl) _ = r.RouteForAllUntilTxTun(theirControl)
@@ -51,39 +47,6 @@ func BenchmarkHotPath(b *testing.B) {
theirControl.Stop() theirControl.Stop()
} }
func BenchmarkHotPathRelay(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(b, myControl, relayControl, theirControl)
r.CancelFlowLogs()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
}
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
@@ -134,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Stop() theirControl.Stop()
} }
func TestGoodHandshakeNoOverlap(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
empty := []byte{}
t.Log("do something to cause a handshake")
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
myControl.WaitForType(header.Test, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestWrongResponderHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -536,35 +464,6 @@ func TestRelays(t *testing.T) {
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
} }
func TestRelaysDontCareAboutIps(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestReestablishRelays(t *testing.T) { func TestReestablishRelays(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -1328,109 +1227,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[0].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
myCfg := m{
"tun": m{
"unsafe_routes": []m{route},
},
}
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
spookyDest := netip.MustParseAddr("192.168.6.4")
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -22,14 +22,15 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "gopkg.in/yaml.v3"
"go.yaml.in/yaml/v3"
) )
type m = map[string]any type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") { for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
budpIp[3] = 239 budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
} }
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides) _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
}
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
}
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
firewallInbound := []m{{
"proto": "any",
"port": "any",
"host": "any",
}}
var unsafeNetworks []netip.Prefix
if sUnsafeNetworks != "" {
firewallInbound = []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "0.0.0.0/0",
}}
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
unsafeNetworks = append(unsafeNetworks, x)
}
}
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": "any", "port": "any",
"host": "any", "host": "any",
}}, }},
"inbound": firewallInbound, "inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
}, },
//"handshakes": m{ //"handshakes": m{
// "try_interval": "1s", // "try_interval": "1s",
@@ -171,109 +129,6 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
return control, vpnNetworks, udpAddr, c return control, vpnNetworks, udpAddr, c
} }
// newServer creates a nebula instance with fewer assumptions
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnNetworks := certs[len(certs)-1].Networks()
var udpAddr netip.AddrPort
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
caStr := ""
for _, ca := range caCrt {
x, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
caStr += string(x)
}
certStr := ""
for _, c := range certs {
x, err := c.MarshalPEM()
if err != nil {
panic(err)
}
certStr += string(x)
}
mc := m{
"pki": m{
"ca": caStr,
"cert": certStr,
"key": string(key),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.Addr().String(),
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
}
if overrides != nil {
final := m{}
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
c := config.NewC(l)
cStr := string(cb)
c.LoadString(cStr)
control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil {
panic(err)
}
return control, vpnNetworks, udpAddr, c
}
type doneCb func() type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb { func deadline(t *testing.T, seconds time.Duration) doneCb {
@@ -292,7 +147,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
} }
} }
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -304,14 +159,14 @@ func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *n
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
} }
func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
// Get both host infos // Get both host infos
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
@@ -325,7 +180,7 @@ func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpn
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
} }
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() { if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else { } else {
@@ -333,7 +188,7 @@ func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
} }
} }
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found") assert.NotNil(t, v6, "No ipv6 data found")
@@ -352,7 +207,7 @@ func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect") assert.Equal(t, expected, data.Payload(), "Data was incorrect")
} }
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found") assert.NotNil(t, v4, "No ipv4 data found")

View File

@@ -4,16 +4,12 @@
package e2e package e2e
import ( import (
"fmt"
"net/netip"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
) )
func TestDropInactiveTunnels(t *testing.T) { func TestDropInactiveTunnels(t *testing.T) {
@@ -59,309 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestCertUpgrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCert2Pem),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": myC.Settings["firewall"],
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v2-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if c == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
r.Logf("version %d", version)
if version == cert.Version2 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*10 {
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertDowngrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCertPem),
"key": string(myPrivKey),
},
"firewall": myC.Settings["firewall"],
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v1-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == cert.Version1 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertMismatchCorrection(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == theirVersion {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}

View File

@@ -383,9 +383,8 @@ firewall:
# host: `any` or a literal hostname, ie `test-host` # host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case. # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# ca_name: An issuing CA name # ca_name: An issuing CA name

View File

@@ -8,7 +8,6 @@ import (
"hash/fnv" "hash/fnv"
"net/netip" "net/netip"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -23,7 +22,7 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
} }
type conn struct { type conn struct {
@@ -248,11 +247,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp.IsValid() {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -260,7 +270,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@@ -287,7 +297,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto) return fmt.Errorf("unknown protocol %v", proto)
} }
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
} }
// GetRuleHash returns a hash representation of all inbound and outbound rules // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -327,6 +337,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
} }
for i, t := range rs { for i, t := range rs {
var groups []string
r, err := convertRule(l, t, table, i) r, err := convertRule(l, t, table, i)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err) return fmt.Errorf("%s rule #%v; %s", table, i, err)
@@ -336,10 +347,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
} }
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
} }
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string var sPort, errPort string
if r.Code != "" { if r.Code != "" {
errPort = "code" errPort = "code"
@@ -368,25 +392,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
} }
if r.Cidr != "" && r.Cidr != "any" { var cidr netip.Prefix
_, err = netip.ParsePrefix(r.Cidr) if r.Cidr != "" {
cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
} }
} }
if r.LocalCidr != "" && r.LocalCidr != "any" { var localCidr netip.Prefix
_, err = netip.ParsePrefix(r.LocalCidr) if r.LocalCidr != "" {
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
} }
} }
if warning := r.sanity(); warning != nil { err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
l.Warnf("%s rule #%v; %s", table, i, warning)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err) return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
} }
@@ -395,10 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil return nil
} }
var ErrUnknownNetworkType = errors.New("unknown network type") var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrPeerRejected = errors.New("remote address is not within a network that we handle") var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table") var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It // Drop returns an error if the packet should be dropped, explaining why. It
@@ -409,31 +429,18 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
return nil return nil
} }
// Make sure remote address matches nebula certificate, and determine how to treat it // Make sure remote address matches nebula certificate
if h.networks == nil { if h.networks != nil {
// Simple case: Certificate has one address and no unsafe networks if !h.networks.Contains(fp.RemoteAddr) {
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} else { } else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr) // Simple case: Certificate has one address and no unsafe networks
if !ok { if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
@@ -633,7 +640,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
return false return false
} }
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port was lower than end port") return fmt.Errorf("start port was lower than end port")
} }
@@ -646,7 +653,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
} }
} }
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil { if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err return err
} }
} }
@@ -677,7 +684,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return fp[firewall.PortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error { func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule { fr := func() *FirewallRule {
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR), Hosts: make(map[string]*firewallLocalCIDR),
@@ -691,14 +698,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
fc.Any = fr() fc.Any = fr()
} }
return fc.Any.addRule(f, groups, host, cidr, localCidr) return fc.Any.addRule(f, groups, host, ip, localIp)
} }
if caSha != "" { if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok { if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr() fc.CAShas[caSha] = fr()
} }
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr) err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -708,7 +715,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
if _, ok := fc.CANames[caName]; !ok { if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr() fc.CANames[caName] = fr()
} }
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr) err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -740,24 +747,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
return fc.CANames[s.Certificate.Name()].match(p, c) return fc.CANames[s.Certificate.Name()].match(p, c)
} }
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error { func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR { flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{ return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite), LocalCIDR: new(bart.Lite),
} }
} }
if fr.isAny(groups, host, cidr) { if fr.isAny(groups, host, ip) {
if fr.Any == nil { if fr.Any == nil {
fr.Any = flc() fr.Any = flc()
} }
return fr.Any.addRule(f, localCidr) return fr.Any.addRule(f, localCIDR)
} }
if len(groups) > 0 { if len(groups) > 0 {
nlc := flc() nlc := flc()
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
@@ -773,34 +780,30 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localC
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.Hosts[host] = nlc fr.Hosts[host] = nlc
} }
if cidr != "" { if ip.IsValid() {
c, err := netip.ParsePrefix(cidr) nlc, _ := fr.CIDR.Get(ip)
if err != nil {
return err
}
nlc, _ := fr.CIDR.Get(c)
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err = nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.CIDR.Insert(c, nlc) fr.CIDR.Insert(ip, nlc)
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && cidr == "" { if len(groups) == 0 && host == "" && !ip.IsValid() {
return true return true
} }
@@ -814,7 +817,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
return true return true
} }
if cidr == "any" { if ip.IsValid() && ip.Bits() == 0 {
return true return true
} }
@@ -866,13 +869,8 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
return false return false
} }
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if localCidr == "any" { if !localIp.IsValid() {
flc.Any = true
return nil
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true flc.Any = true
return nil return nil
@@ -883,13 +881,12 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
} }
return nil return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
} }
c, err := netip.ParsePrefix(localCidr) flc.LocalCIDR.Insert(localIp)
if err != nil {
return err
}
flc.LocalCIDR.Insert(c)
return nil return nil
} }
@@ -910,6 +907,7 @@ type rule struct {
Code string Code string
Proto string Proto string
Host string Host string
Group string
Groups []string Groups []string
Cidr string Cidr string
LocalCidr string LocalCidr string
@@ -951,8 +949,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
m["group"] = v[0] m["group"] = v[0]
} }
r.Group = toString("group", m)
singleGroup := toString("group", m)
if rg, ok := m["groups"]; ok { if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() { switch reflect.TypeOf(rg).Kind() {
@@ -969,60 +966,9 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
} }
} }
//flatten group vs groups
if singleGroup != "" {
// Check if we have both groups and group provided in the rule config
if len(r.Groups) > 0 {
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
}
r.Groups = []string{singleGroup}
}
return r, nil return r, nil
} }
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
func (r *rule) sanity() error {
//port, proto, local_cidr are AND, no need to check here
//ca_sha and ca_name don't have a wildcard value, no need to check here
groupsEmpty := len(r.Groups) == 0
hostEmpty := r.Host == ""
cidrEmpty := r.Cidr == ""
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
return nil //no content!
}
groupsHasAny := slices.Contains(r.Groups, "any")
if groupsHasAny && len(r.Groups) > 1 {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
}
if r.Host == "any" {
if !groupsEmpty {
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
}
if !cidrEmpty {
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
}
}
if groupsHasAny {
if !hostEmpty && r.Host != "any" {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
}
if !cidrEmpty {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
}
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) { func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" { if s == "any" {
startPort = firewall.PortAny startPort = firewall.PortAny

View File

@@ -8,8 +8,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -73,114 +71,85 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128") ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0") anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -205,10 +174,10 @@ func TestFirewall_Drop(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -227,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -257,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -284,10 +250,10 @@ func TestFirewall_DropV6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -306,28 +272,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -338,12 +304,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
pfix := netip.MustParsePrefix("172.1.1.1/32") pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128") pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@@ -487,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -514,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.CachedCertificate{ c1 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -529,10 +493,10 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1, peerCert: &c1,
}, },
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@@ -546,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -579,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.CachedCertificate{ c2 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -594,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h2.buildNetworks(myVpnNetworksTable, c2.Certificate) h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.CachedCertificate{ c3 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -609,11 +571,11 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h3.buildNetworks(myVpnNetworksTable, c3.Certificate) h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@@ -627,7 +589,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match // Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
@@ -635,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
@@ -660,12 +620,12 @@ func TestFirewall_Drop3V6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match // Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool() cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -673,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -701,10 +659,10 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -717,7 +675,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -726,7 +684,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -738,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
c := cert.CachedCertificate{ c := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -761,11 +717,11 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -959,28 +915,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8") cidr := netip.MustParsePrefix("10.0.0.0/8")
@@ -988,14 +944,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6 // Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8") cidr6 := netip.MustParsePrefix("fd00::/8")
@@ -1003,75 +959,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6 // Test adding rule with local_cidr ipv6
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(l)
@@ -1094,7 +1024,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
// Ensure group array of > 1 is errord // Ensure group array of > 1 is errord
ob.Reset() ob.Reset()
@@ -1114,228 +1044,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"host": "bob"},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range noWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
}
yesWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
c["host"] = "any"
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
//reset the list
yesWarningPlease = []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
r.Groups = append(r.Groups, "any")
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
}
type testcase struct {
h *HostInfo
p firewall.Packet
c cert.Certificate
err error
}
func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c1,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
}
for i := range theirPrefixes {
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
return testcase{
h: &h,
p: p,
c: &c1,
err: err,
}
}
type testsetup struct {
c dummyCert
myVpnNetworksTable *bart.Lite
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
c: c,
fw: fw,
myVpnNetworksTable: myVpnNetworksTable,
}
}
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
tc.Test(t, setup.fw)
})
t.Run("allow inbound local matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
tc.Test(t, setup.fw)
})
t.Run("block inbound remote mismatched", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("Block a vpn peer packet", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
tc.Test(t, setup.fw)
})
twoPrefixes := []netip.Prefix{
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
}
t.Run("allow inbound one matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, twoPrefixes...)
tc.Test(t, setup.fw)
})
t.Run("block inbound multimismatch", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
t.Parallel()
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
tc := buildTestCase(setup2, nil, twoPrefixes...)
tc.p.RemoteAddr = twoPrefixes[1].Addr()
tc.Test(t, setup2.fw)
})
t.Run("allow inbound unsafe route", func(t *testing.T) {
t.Parallel()
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
unsafeNetworks: []netip.Prefix{unsafePrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
unsafeSetup := newSetupFromCert(t, l, c)
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
} }
type addRuleCall struct { type addRuleCall struct {
@@ -1345,8 +1054,8 @@ type addRuleCall struct {
endPort int32 endPort int32
groups []string groups []string
host string host string
ip string ip netip.Prefix
localIp string localIp netip.Prefix
caName string caName string
caSha string caSha string
} }
@@ -1356,7 +1065,7 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,

21
go.mod
View File

@@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0 github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.26.0 github.com/gaissmai/bart v0.25.0
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4 github.com/kardianos/service v1.2.4
@@ -22,19 +22,18 @@ require (
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.45.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.47.0 golang.org/x/net v0.45.0
golang.org/x/sync v0.18.0 golang.org/x/sync v0.17.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.37.0
golang.org/x/term v0.37.0 golang.org/x/term v0.36.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.10 google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
) )
require ( require (
@@ -50,6 +49,6 @@ require (
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.24.0 // indirect golang.org/x/mod v0.24.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.7.0 // indirect
golang.org/x/tools v0.33.0 // indirect golang.org/x/tools v0.33.0 // indirect
) )

42
go.sum
View File

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -155,15 +155,13 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +189,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -209,16 +207,16 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
@@ -232,8 +230,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
@@ -244,8 +242,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -259,5 +257,5 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=

View File

@@ -2,6 +2,7 @@ package nebula
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
@@ -22,19 +23,15 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false return false
} }
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState() cs := f.pki.getCertState()
v := cs.initiatingVersion v := cs.initiatingVersion
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
// If we're connecting to a v6 address we should encourage use of a V2 cert
for _, a := range hh.hostinfo.vpnAddrs { for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() { if a.Is6() {
v = cert.Version2 v = cert.Version2
break break
} }
} }
}
crt := cs.getCertificate(v) crt := cs.getCertificate(v)
if crt == nil { if crt == nil {
@@ -51,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v). WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available") Error("Unable to handshake with host because no certificate handshake bytes is available")
return false
} }
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -107,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion). WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available") Error("Unable to handshake with host because no certificate is available")
return
} }
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -148,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil { if err != nil {
fp, fperr := rc.Fingerprint() fp, err := rc.Fingerprint()
if fperr != nil { if err != nil {
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
@@ -168,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if remoteCert.Certificate.Version() != ci.myCert.Version() { if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us // We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) rc := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil { if rc == nil {
if f.l.Level >= logrus.DebugLevel { f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithFields(m{ WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
"udpAddr": addr, Info("Unable to handshake with host due to missing certificate version")
"handshake": m{"stage": 1, "style": "ix_psk0"}, return
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
} }
} else {
// Record the certificate we are actually using // Record the certificate we are actually using
ci.myCert = myCertOtherVersion ci.myCert = rc
}
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
@@ -191,17 +183,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version() certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
vpnNetworks := remoteCert.Certificate.Networks()
anyVpnAddrsInCommon := false for _, network := range remoteCert.Certificate.Networks() {
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddr := network.Addr()
for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(vpnAddr) {
if f.myVpnAddrsTable.Contains(network.Addr()) { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -209,10 +201,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return return
} }
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { // vpnAddrs outside our vpn networks are of no use to us, filter them out
anyVpnAddrsInCommon = true if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
} }
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
} }
if addr.IsValid() { if addr.IsValid() {
@@ -249,30 +255,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}, },
} }
msgRxL := f.l.WithFields(m{ f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
"vpnAddrs": vpnAddrs, WithField("certName", certName).
"udpAddr": addr, WithField("certVersion", certVersion).
"certName": certName, WithField("fingerprint", fingerprint).
"certVersion": certVersion, WithField("issuer", issuer).
"fingerprint": fingerprint, WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
"issuer": issuer, WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
"initiatorIndex": hs.Details.InitiatorIndex, Info("Handshake message received")
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil { if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available") Error("Unable to handshake with host because no certificate handshake bytes is available")
return return
} }
@@ -330,7 +332,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil { if err != nil {
@@ -571,22 +573,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
correctHostResponded := false var vpnAddrs []netip.Addr
anyVpnAddrsInCommon := false var filteredNetworks []netip.Prefix
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for _, network := range vpnNetworks {
for i, network := range vpnNetworks { // vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddrs[i] = network.Addr() vpnAddr := network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { if !f.myVpnNetworksTable.Contains(vpnAddr) {
anyVpnAddrsInCommon = true continue
} }
if hostinfo.vpnAddrs[0] == network.Addr() {
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? filteredNetworks = append(filteredNetworks, network)
correctHostResponded = true vpnAddrs = append(vpnAddrs, vpnAddr)
} }
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
} }
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr). WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
@@ -598,7 +609,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip // Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
@@ -625,7 +635,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -633,17 +643,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration). WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)) WithField("sentCachedPackets", len(hh.packetStore)).
if anyVpnAddrsInCommon { Info("Handshake message received")
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)

View File

@@ -70,7 +70,6 @@ type HandshakeHostInfo struct {
startTime time.Time // Time that we first started trying with this handshake startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready ready bool // Is the handshake ready
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
@@ -269,12 +268,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to // Don't relay to myself
if relay == vpnIp { if relay == vpnIp {
continue continue
} }
// Don't relay to myself // Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) { if hm.f.myVpnAddrsTable.Contains(relay) {
continue continue
} }

View File

@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r rs.relayForByIdx[idx] = r
} }
type NetworkType uint8
const (
NetworkTypeUnknown NetworkType = iota
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
NetworkTypeVPN
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
)
type HostInfo struct { type HostInfo struct {
remote netip.AddrPort remote netip.AddrPort
remotes *RemoteList remotes *RemoteList
@@ -237,8 +225,8 @@ type HostInfo struct {
// vpn networks but were removed because they are not usable // vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. // networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[NetworkType] networks *bart.Lite
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@@ -742,26 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false return false
} }
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up. func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) { if len(networks) == 1 && len(unsafeNetworks) == 0 {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) { return
return // Simple case, no BART needed
}
} }
i.networks = new(bart.Table[NetworkType]) i.networks = new(bart.Lite)
for _, network := range c.Networks() { for _, network := range networks {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
if myVpnNetworksTable.Contains(network.Addr()) { i.networks.Insert(nprefix)
i.networks.Insert(nprefix, NetworkTypeVPN)
} else {
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
}
} }
for _, network := range c.UnsafeNetworks() { for _, network := range unsafeNetworks {
i.networks.Insert(network, NetworkTypeUnsafe) i.networks.Insert(network)
} }
} }

View File

@@ -33,8 +33,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
// routes packets from the Nebula addr to the Nebula addr through the Nebula // routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device. // TUN device.
if immediatelyForwardToSelf { if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet) if err := f.writeTun(q, packet); err != nil {
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun") f.l.WithError(err).Error("Failed to forward to tun")
} }
} }
@@ -91,8 +90,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
return return
} }
_, err := f.readers[q].Write(out) if err := f.writeTun(q, out); err != nil {
if err != nil {
f.l.WithError(err).Error("Failed to write to tun") f.l.WithError(err).Error("Failed to write to tun")
} }
} }
@@ -120,10 +118,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
// it does not check if it is within our vpn networks!
func (f *Interface) Handshake(vpnAddr netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.handshakeManager.GetOrHandshake(vpnAddr, nil) f.getOrHandshakeNoRouting(vpnAddr, nil)
} }
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -139,6 +136,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -231,10 +229,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })

View File

@@ -47,6 +47,7 @@ type InterfaceConfig struct {
reQueryWait time.Duration reQueryWait time.Duration
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
batchSize int
l *logrus.Logger l *logrus.Logger
} }
@@ -84,6 +85,7 @@ type Interface struct {
version string version string
conntrackCacheTimeout time.Duration conntrackCacheTimeout time.Duration
batchSize int
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
@@ -110,6 +112,16 @@ type EncWriter interface {
GetCertState() *CertState GetCertState() *CertState
} }
// BatchReader is an interface for readers that support vectorized packet reading
type BatchReader interface {
BatchRead(buffers [][]byte, sizes []int) (int, error)
}
// BatchWriter is an interface for writers that support vectorized packet writing
type BatchWriter interface {
BatchWrite([][]byte) (int, error)
}
type sendRecvErrorConfig uint8 type sendRecvErrorConfig uint8
const ( const (
@@ -186,6 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
relayManager: c.relayManager, relayManager: c.relayManager,
connectionManager: c.connectionManager, connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
batchSize: c.batchSize,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics, messageMetrics: c.MessageMetrics,
@@ -222,13 +235,6 @@ func (f *Interface) activate() {
WithField("boringcrypto", boringEnabled()). WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active") Info("Nebula interface is active")
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
f.routines = 1
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
}
}
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
@@ -276,7 +282,7 @@ func (f *Interface) listenOut(i int) {
plaintext := make([]byte, udp.MTU) plaintext := make([]byte, udp.MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
@@ -286,6 +292,16 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
// Check if reader supports batch operations
if batchReader, ok := reader.(BatchReader); ok {
err := f.listenInBatch(batchReader, i)
if err != nil {
f.l.WithError(err).Error("Fatal error in batch packet reader, exiting goroutine")
}
return
}
// Fall back to single-packet mode
packet := make([]byte, mtu) packet := make([]byte, mtu)
out := make([]byte, mtu) out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
@@ -300,15 +316,85 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
return return
} }
f.l.WithError(err).Error("Error while reading outbound packet") f.l.WithError(err).Error("Fatal error while reading outbound packet, exiting goroutine")
// This only seems to happen when something fatal happens to the fd, so exit. return
os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
} }
} }
// listenInBatch handles vectorized packet reading for improved performance
func (f *Interface) listenInBatch(reader BatchReader, i int) error {
// Allocate per-packet state and buffers for batch reading
batchSize := f.batchSize
if batchSize <= 0 {
batchSize = 64 // Fallback to default if not configured
}
fwPackets := make([]*firewall.Packet, batchSize)
outBuffers := make([][]byte, batchSize)
nbBuffers := make([][]byte, batchSize)
packets := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for j := 0; j < batchSize; j++ {
fwPackets[j] = &firewall.Packet{}
outBuffers[j] = make([]byte, mtu)
nbBuffers[j] = make([]byte, 12)
packets[j] = make([]byte, mtu)
}
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.BatchRead(packets, sizes)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return nil
}
return fmt.Errorf("error while batch reading outbound packets: %w", err)
}
// Process each packet in the batch
cache := conntrackCache.Get(f.l)
for idx := 0; idx < n; idx++ {
if sizes[idx] > 0 {
// Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
stateIdx := idx % len(fwPackets)
f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)
}
}
}
}
// writeTunBatch attempts to write multiple packets to the TUN device using batch operations if supported
func (f *Interface) writeTunBatch(q int, packets [][]byte) error {
if len(packets) == 0 {
return nil
}
// Check if the reader/writer supports batch operations
if batchWriter, ok := f.readers[q].(BatchWriter); ok {
_, err := batchWriter.BatchWrite(packets)
return err
}
// Fall back to writing packets individually
for _, packet := range packets {
if _, err := f.readers[q].Write(packet); err != nil {
return err
}
}
return nil
}
// writeTun writes a single packet to the TUN device
func (f *Interface) writeTun(q int, packet []byte) error {
_, err := f.readers[q].Write(packet)
return err
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError) c.RegisterReloadCallback(f.reloadSendRecvError)

View File

@@ -360,8 +360,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
} }
out[i] = addr out[i] = addr
} }
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
} }
if !lh.myVpnNetworksTable.Contains(vpnAddr) { if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
} }
vals, ok := v.([]any) vals, ok := v.([]any)
@@ -1339,19 +1337,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
} }
} }
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts { for _, a := range n.Details.V4AddrPorts {
b := protoV4AddrPortToNetAddrPort(a) punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
} }
for _, a := range n.Details.V6AddrPorts { for _, a := range n.Details.V6AddrPorts {
b := protoV6AddrPortToNetAddrPort(a) punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
} }
// This sends a nebula test packet to the host trying to contact us. In the case // This sends a nebula test packet to the host trying to contact us. In the case

View File

@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {

27
main.go
View File

@@ -5,8 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime/debug"
"strings"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -15,7 +13,7 @@ import (
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any
@@ -29,10 +27,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
}() }()
if buildVersion == "" {
buildVersion = moduleVersion()
}
l := logger l := logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
@@ -81,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c) sshStart, err = configSSH(l, ssh, c)
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
sshStart = nil
} }
} }
@@ -249,6 +242,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
relayManager: NewRelayManager(ctx, l, hostMap, c), relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy, punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
batchSize: c.GetInt("tun.batch_size", 64),
l: l, l: l,
} }
@@ -302,18 +296,3 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
connManager.Start, connManager.Start,
}, nil }, nil
} }
func moduleVersion() string {
info, ok := debug.ReadBuildInfo()
if !ok {
return ""
}
for _, dep := range info.Deps {
if dep.Path == "github.com/slackhq/nebula" {
return strings.TrimPrefix(dep.Version, "v")
}
}
return ""
}

View File

@@ -333,12 +333,13 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
} }
fp.Protocol = uint8(proto) fp.Protocol = uint8(proto)
ports := data[offset : offset+4]
if incoming { if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) fp.RemotePort = binary.BigEndian.Uint16(ports[0:2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) fp.LocalPort = binary.BigEndian.Uint16(ports[2:4])
} else { } else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) fp.LocalPort = binary.BigEndian.Uint16(ports[0:2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) fp.RemotePort = binary.BigEndian.Uint16(ports[2:4])
} }
fp.Fragment = false fp.Fragment = false

View File

@@ -13,6 +13,5 @@ type Device interface {
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
} }

View File

@@ -3,7 +3,6 @@ package overlay
import ( import (
"fmt" "fmt"
"math" "math"
"net"
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
@@ -305,29 +304,3 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return routes, nil return routes, nil
} }
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
// Make sure o contains the lowest form of i
if !o.Contains(i.IP.Mask(i.Mask)) {
return false
}
// Find the max ip in i
ip4 := i.IP.To4()
if ip4 == nil {
return false
}
last := make(net.IP, len(ip4))
copy(last, ip4)
for x := range ip4 {
last[x] |= ^i.Mask[x]
}
// Make sure o contains the max
if !o.Contains(last) {
return false
}
return true
}

View File

@@ -225,6 +225,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
// no mtu // no mtu
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU) assert.Equal(t, 0, routes[0].MTU)
@@ -318,7 +319,7 @@ func Test_makeRouteTree(t *testing.T) {
ip, err = netip.ParseAddr("1.1.0.1") ip, err = netip.ParseAddr("1.1.0.1")
require.NoError(t, err) require.NoError(t, err)
r, ok = routeTree.Lookup(ip) _, ok = routeTree.Lookup(ip)
assert.False(t, ok) assert.False(t, ok)
} }

View File

@@ -1,8 +1,6 @@
package overlay package overlay
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -72,51 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed return removed
} }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
for _, gateway := range gateways {
if dest.Addr().Is4() && gateway.Addr().Is4() {
return gateway, nil
}
if dest.Addr().Is6() && gateway.Addr().Is6() {
return gateway, nil
}
}
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
}

View File

@@ -95,10 +95,6 @@ func (t *tun) Name() string {
return "android" return "android"
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android") return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
} }

View File

@@ -1,5 +1,5 @@
//go:build !ios && !e2e_testing //go:build darwin && !ios && !e2e_testing
// +build !ios,!e2e_testing // +build darwin,!ios,!e2e_testing
package overlay package overlay
@@ -8,48 +8,27 @@ import (
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
"os"
"sync/atomic"
"syscall"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
) )
type tun struct { type tun struct {
io.ReadWriteCloser
Device string
vpnNetworks []netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
} }
// ioctl structures for Darwin network configuration
type ifReq struct { type ifReq struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Flags uint16 Flags uint16
pad [8]byte pad [8]byte
} }
const (
_SIOCAIFADDR_IN6 = 2155899162
_UTUN_OPT_IFNAME = 2
_IN6_IFF_NODAD = 0x0020
_IN6_IFF_SECURED = 0x0400
utunControlName = "com.apple.net.utun_control"
)
type ifreqMTU struct { type ifreqMTU struct {
Name [16]byte Name [16]byte
MTU int32 MTU int32
@@ -79,60 +58,61 @@ type ifreqAlias6 struct {
Lifetime addrLifetime Lifetime addrLifetime
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { const (
_SIOCAIFADDR_IN6 = 2155899162
_IN6_IFF_NODAD = 0x0020
)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported on Darwin")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
name := c.GetString("tun.dev", "") name := c.GetString("tun.dev", "")
ifIndex := -1 deviceName := "utun"
// Parse device name to handle utun[0-9]+ format
if name != "" && name != "utun" { if name != "" && name != "utun" {
ifIndex := -1
_, err := fmt.Sscanf(name, "utun%d", &ifIndex) _, err := fmt.Sscanf(name, "utun%d", &ifIndex)
if err != nil || ifIndex < 0 { if err != nil || ifIndex < 0 {
// NOTE: we don't make this error so we don't break existing // NOTE: we don't make this error so we don't break existing
// configs that set a name before it was used. // configs that set a name before it was used.
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring") l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring")
ifIndex = -1 } else {
deviceName = name
} }
} }
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL) mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create WireGuard TUN device
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("system socket: %v", err) return nil, fmt.Errorf("failed to create TUN device: %w", err)
} }
var ctlInfo = &unix.CtlInfo{} // Get the actual device name
copy(ctlInfo.Name[:], utunControlName) actualName, err := tunDevice.Name()
err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("CTLIOCGINFO: %v", err) tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
} }
err = unix.Connect(fd, &unix.SockaddrCtl{ t := &wgTun{
ID: ctlInfo.Id, tunDevice: tunDevice,
Unit: uint32(ifIndex) + 1,
})
if err != nil {
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
}
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
if err != nil {
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
}
err = unix.SetNonblock(fd, true)
if err != nil {
return nil, fmt.Errorf("SetNonblock: %v", err)
}
t := &tun{
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
Device: name,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), MaxMTU: mtu,
DefaultMTU: mtu,
l: l, l: l,
} }
// Create Darwin-specific route manager
t.routeManager = &tun{}
err = t.reload(c, true) err = t.reload(c, true)
if err != nil { if err != nil {
tunDevice.Close()
return nil, err return nil, err
} }
@@ -143,215 +123,251 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
}) })
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil return t, nil
} }
func (t *tun) deviceBytes() (o [16]byte) { func (rm *tun) Activate(t *wgTun) error {
for i, c := range t.Device { name, err := t.tunDevice.Name()
o[i] = byte(c)
}
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func (t *tun) Activate() error {
devName := t.deviceBytes()
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil { if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
}
// Set the MTU
rm.SetMTU(t, t.MaxMTU)
// Add IP addresses
for _, network := range t.vpnNetworks {
if err := rm.addIP(t, name, network); err != nil {
return err return err
} }
defer unix.Close(s)
fd := uintptr(s)
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return fmt.Errorf("failed to set tun mtu: %v", err)
} }
// Get the device flags // Bring up the interface using ioctl
ifrf := ifReq{Name: devName} if err := rm.bringUpInterface(name); err != nil {
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring up interface: %w", err)
return fmt.Errorf("failed to get tun flags: %s", err)
} }
linkAddr, err := getLinkAddr(t.Device) // Get the link address for routing
linkAddr, err := getLinkAddr(name)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get link address: %w", err)
} }
if linkAddr == nil { if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface") return fmt.Errorf("unable to discover link_addr for tun interface")
} }
t.linkAddr = linkAddr rm.linkAddr = linkAddr
for _, network := range t.vpnNetworks { // Set the routes
if network.Addr().Is4() { if err := rm.AddRoutes(t, false); err != nil {
err = t.activate4(network)
if err != nil {
return err return err
} }
return nil
}
func (rm *tun) bringUpInterface(name string) error {
// Open a socket for ioctl
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
if err != nil {
return fmt.Errorf("failed to create socket: %w", err)
}
defer unix.Close(fd)
// Get current flags
var ifrf ifReq
copy(ifrf.Name[:], name)
if err := ioctl(uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to get interface flags: %w", err)
}
// Set IFF_UP and IFF_RUNNING flags
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err := ioctl(uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set interface flags: %w", err)
}
return nil
}
func (rm *tun) SetMTU(t *wgTun, mtu int) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for MTU set")
return
}
// Open a socket for ioctl
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
if err != nil {
t.l.WithError(err).Error("Failed to create socket for MTU set")
return
}
defer unix.Close(fd)
// Prepare the ioctl request
var ifr ifreqMTU
copy(ifr.Name[:], name)
ifr.MTU = int32(mtu)
// Set the MTU using ioctl
if err := ioctl(uintptr(fd), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifr))); err != nil {
t.l.WithError(err).Error("Failed to set tun mtu via ioctl")
}
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
// On Darwin, routes are set via ifconfig and route commands
return nil
}
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Install {
continue
}
err := rm.addRoute(r.Cidr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
} else { } else {
err = t.activate6(network) t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
err := rm.delRoute(r.Cidr)
if err != nil { if err != nil {
return err t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
} }
} }
} }
// Run the interface func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING // Darwin doesn't support multi-queue TUN devices in the same way as Linux
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { // Return a reader that wraps the same device
return fmt.Errorf("failed to run tun device: %s", err) return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
} }
// Unsafe path routes func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
return t.addRoutes(false) addr := network.Addr()
if addr.Is4() {
return rm.addIPv4(name, network)
} else {
return rm.addIPv6(name, network)
}
} }
func (t *tun) activate4(network netip.Prefix) error { func (rm *tun) addIPv4(name string, network netip.Prefix) error {
s, err := unix.Socket( // Open an IPv4 socket for ioctl
unix.AF_INET, s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to create IPv4 socket: %w", err)
} }
defer unix.Close(s) defer unix.Close(s)
ifr := ifreqAlias4{ var ifr ifreqAlias4
Name: t.deviceBytes(), copy(ifr.Name[:], name)
Addr: unix.RawSockaddrInet4{
// Set the address
ifr.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4, Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET, Family: unix.AF_INET,
Addr: network.Addr().As4(), Addr: network.Addr().As4(),
}, }
DstAddr: unix.RawSockaddrInet4{
// Set the destination address (same as address for point-to-point)
ifr.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4, Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET, Family: unix.AF_INET,
Addr: network.Addr().As4(), Addr: network.Addr().As4(),
}, }
MaskAddr: unix.RawSockaddrInet4{
// Set the netmask
ifr.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4, Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET, Family: unix.AF_INET,
Addr: prefixToMask(network).As4(), Addr: prefixToMask(network).As4(),
},
} }
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun v4 address: %s", err) return fmt.Errorf("failed to set IPv4 address via ioctl: %w", err)
}
err = addRoute(network, t.linkAddr)
if err != nil {
return err
} }
return nil return nil
} }
func (t *tun) activate6(network netip.Prefix) error { func (rm *tun) addIPv6(name string, network netip.Prefix) error {
s, err := unix.Socket( // Open an IPv6 socket for ioctl
unix.AF_INET6, s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to create IPv6 socket: %w", err)
} }
defer unix.Close(s) defer unix.Close(s)
ifr := ifreqAlias6{ var ifr ifreqAlias6
Name: t.deviceBytes(), copy(ifr.Name[:], name)
Addr: unix.RawSockaddrInet6{
// Set the address
ifr.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6, Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6, Family: unix.AF_INET6,
Addr: network.Addr().As16(), Addr: network.Addr().As16(),
}, }
PrefixMask: unix.RawSockaddrInet6{
// Set the prefix mask
ifr.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6, Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6, Family: unix.AF_INET6,
Addr: prefixToMask(network).As16(), Addr: prefixToMask(network).As16(),
}, }
Lifetime: addrLifetime{
// never expires // Set lifetime (never expires)
ifr.Lifetime = addrLifetime{
Vltime: 0xffffffff, Vltime: 0xffffffff,
Pltime: 0xffffffff, Pltime: 0xffffffff,
},
Flags: _IN6_IFF_NODAD,
} }
// Set flags (no DAD - Duplicate Address Detection)
ifr.Flags = _IN6_IFF_NODAD
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err) return fmt.Errorf("failed to set IPv6 address via ioctl: %w", err)
} }
return nil return nil
} }
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, ok := t.routeTree.Load().Lookup(ip)
if ok {
return r
}
return routing.Gateways{}
}
// Get the LinkAddr for the interface of the given name
// Is there an easier way to fetch this when we create the interface?
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
func getLinkAddr(name string) (*netroute.LinkAddr, error) { func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil { if err != nil {
@@ -377,53 +393,7 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
return nil, nil return nil, nil
} }
func (t *tun) addRoutes(logErrors bool) error { func (rm *tun) addRoute(prefix netip.Prefix) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
@@ -441,13 +411,13 @@ func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
route.Addrs = []netroute.Addr{ route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway, unix.RTAX_GATEWAY: rm.linkAddr,
} }
} else { } else {
route.Addrs = []netroute.Addr{ route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway, unix.RTAX_GATEWAY: rm.linkAddr,
} }
} }
@@ -464,7 +434,7 @@ func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
return nil return nil
} }
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { func (rm *tun) delRoute(prefix netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
@@ -481,13 +451,13 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
route.Addrs = []netroute.Addr{ route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway, unix.RTAX_GATEWAY: rm.linkAddr,
} }
} else { } else {
route.Addrs = []netroute.Addr{ route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway, unix.RTAX_GATEWAY: rm.linkAddr,
} }
} }
@@ -495,6 +465,7 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err) return fmt.Errorf("failed to create route.RouteMessage: %w", err)
} }
_, err = unix.Write(sock, data[:]) _, err = unix.Write(sock, data[:])
if err != nil { if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@@ -503,56 +474,34 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) { func ioctl(a1, a2, a3 uintptr) error {
buf := make([]byte, len(to)+4) _, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
n, err := t.ReadWriteCloser.Read(buf) return errno
}
copy(to, buf[4:]) return nil
return n - 4, err
} }
// Write is only valid for single threaded use func prefixToMask(prefix netip.Prefix) netip.Addr {
func (t *tun) Write(from []byte) (int, error) { bits := prefix.Bits()
buf := t.out if prefix.Addr().Is4() {
if cap(buf) < len(from)+4 { // Create IPv4 netmask from prefix length
buf = make([]byte, len(from)+4) mask := ^uint32(0) << (32 - bits)
t.out = buf return netip.AddrFrom4([4]byte{
} byte(mask >> 24),
buf = buf[:len(from)+4] byte(mask >> 16),
byte(mask >> 8),
if len(from) == 0 { byte(mask),
return 0, syscall.EIO })
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else { } else {
return 0, fmt.Errorf("unable to determine IP version from packet") // Create IPv6 netmask from prefix length
var mask [16]byte
for i := 0; i < bits/8; i++ {
mask[i] = 0xff
} }
if bits%8 != 0 {
copy(buf[4:], from) mask[bits/8] = ^byte(0) << (8 - bits%8)
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
} }
return netip.AddrFrom16(mask)
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
} }
func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }

View File

@@ -105,10 +105,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil return t, nil
} }

View File

@@ -1,284 +1,77 @@
//go:build !e2e_testing //go:build freebsd && !e2e_testing
// +build !e2e_testing // +build freebsd,!e2e_testing
package overlay package overlay
import ( import (
"bytes"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"net/netip" "net/netip"
"sync/atomic" "os/exec"
"strconv"
"strings"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
) )
const ( type tun struct{}
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
TUNSIFMODE = 0x8004745e
TUNSIFHEAD = 0x80047460
OSIOCAIFADDR_IN6 = 0x8088691b
IN6_IFF_NODAD = 0x0020
)
type fiodgnameArg struct {
length int32
pad [4]byte
buf unsafe.Pointer
}
// ifreqRename is used for renaming network interfaces on FreeBSD
type ifreqRename struct { type ifreqRename struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
} }
type ifreqDestroy struct { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
Name [unix.IFNAMSIZ]byte return nil, fmt.Errorf("newTunFromFd not supported on FreeBSD")
pad [16]byte
} }
type ifReq struct { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
Name [unix.IFNAMSIZ]byte deviceName := c.GetString("tun.dev", "tun")
Flags uint16 mtu := c.GetInt("tun.mtu", DefaultMTU)
}
type ifreqMTU struct { // Create WireGuard TUN device
Name [unix.IFNAMSIZ]byte tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
}
func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
}
func (t *tun) Close() error {
if t.devFd >= 0 {
err := syscall.Close(t.devFd)
if err != nil { if err != nil {
t.l.WithError(err).Error("Error closing device") return nil, fmt.Errorf("failed to create TUN device: %w", err)
} }
t.devFd = -1
c := make(chan struct{}) // Get the actual device name
go func() { actualName, err := tunDevice.Name()
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil { if err != nil {
t.l.WithError(err).Error("Error destroying tunnel") tunDevice.Close()
} return nil, fmt.Errorf("failed to get TUN device name: %w", err)
}()
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
}
if err != nil {
return nil, err
}
// Read the name of the interface
var name [16]byte
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
}
if ctrlErr != nil {
return nil, err
}
ifName := string(bytes.TrimRight(name[:], "\x00"))
if deviceName == "" {
deviceName = ifName
} }
// If the name doesn't match the desired interface name, rename it now // If the name doesn't match the desired interface name, rename it now
if ifName != deviceName { if actualName != deviceName && deviceName != "" && deviceName != "tun" {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err := renameInterface(actualName, deviceName); err != nil {
if err != nil { tunDevice.Close()
return nil, err return nil, fmt.Errorf("failed to rename interface from %s to %s: %w", actualName, deviceName, err)
} }
defer syscall.Close(s) actualName = deviceName
fd := uintptr(s)
var fromName [16]byte
var toName [16]byte
copy(fromName[:], ifName)
copy(toName[:], deviceName)
ifrr := ifreqRename{
Name: fromName,
Data: uintptr(unsafe.Pointer(&toName)),
} }
// Set the device name t := &wgTun{
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) tunDevice: tunDevice,
}
t := &tun{
Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MaxMTU: mtu,
DefaultMTU: mtu,
l: l, l: l,
devFd: fd,
} }
// Create FreeBSD-specific route manager
t.routeManager = &tun{}
err = t.reload(c, true) err = t.reload(c, true)
if err != nil { if err != nil {
tunDevice.Close()
return nil, err return nil, err
} }
@@ -289,184 +82,86 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
}) })
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil return t, nil
} }
func (t *tun) addIp(cidr netip.Prefix) error { func (rm *tun) Activate(t *wgTun) error {
if cidr.Addr().Is4() { name, err := t.tunDevice.Name()
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get device name: %w", err)
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
} }
if cidr.Addr().Is6() { // Set the MTU
ifr := ifreqAlias6{ rm.SetMTU(t, t.MaxMTU)
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { // Add IP addresses
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) for _, network := range t.vpnNetworks {
} if err := rm.addIP(t, name, network); err != nil {
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err return err
} }
} }
return t.addRoutes(false) // Bring up the interface
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
return fmt.Errorf("failed to bring up interface: %w", err)
} }
func (t *tun) setMTU() error { // Set the routes
// Set the MTU on the device if err := rm.AddRoutes(t, false); err != nil {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err return err
} }
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil return nil
} }
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { func (rm *tun) SetMTU(t *wgTun, mtu int) {
r, _ := t.routeTree.Load().Lookup(ip) name, err := t.tunDevice.Name()
return r if err != nil {
t.l.WithError(err).Error("Failed to get device name for MTU set")
return
} }
func (t *tun) Networks() []netip.Prefix { if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
return t.vpnNetworks t.l.WithError(err).Error("Failed to set tun mtu")
}
} }
func (t *tun) Name() string { func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
return t.Device // On FreeBSD, routes are set via ifconfig and route commands
return nil
} }
func (t *tun) SupportsMultiqueue() bool { func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
return false name, err := t.tunDevice.Name()
if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if len(r.Via) == 0 || !r.Install { if !r.Install {
// We don't allow route MTUs so only install routes with a via
continue continue
} }
err := addRoute(r.Cidr, t.linkAddr) // Add route using route command
args := []string{"add"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
if r.Metric > 0 {
// FreeBSD doesn't support route metrics directly like Linux
t.l.WithField("route", r).Warn("Route metrics are not fully supported on FreeBSD")
}
err := runCommandBSD("route", args...)
if err != nil { if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
@@ -482,142 +177,99 @@ func (t *tun) addRoutes(logErrors bool) error {
return nil return nil
} }
func (t *tun) removeRoutes(routes []Route) error { func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for route removal")
return
}
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {
continue continue
} }
err := delRoute(r.Cidr, t.linkAddr) args := []string{"delete"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
err := runCommandBSD("route", args...)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.WithField("route", r).Info("Removed route")
} }
} }
return nil
} }
func (t *tun) deviceBytes() (o [16]byte) { func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
for i, c := range t.Device { // FreeBSD doesn't support multi-queue TUN devices in the same way as Linux
o[i] = byte(c) // Return a reader that wraps the same device
} return &wgTunReader{
return parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
} }
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) addr := network.Addr()
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{ if addr.Is4() {
Version: unix.RTM_VERSION, // For IPv4: ifconfig tun0 10.0.0.1/24
Type: unix.RTM_ADD, if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
Flags: unix.RTF_UP, return fmt.Errorf("failed to add IPv4 address: %w", err)
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
} }
} else { } else {
route.Addrs = []netroute.Addr{ // For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, return fmt.Errorf("failed to add IPv6 address: %w", err)
unix.RTAX_GATEWAY: gateway,
} }
} }
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil return nil
} }
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { func runCommandBSD(name string, args ...string) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) cmd := exec.Command(name, args...)
output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
} }
defer unix.Close(sock) return nil
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
} }
if prefix.Addr().Is4() { func renameInterface(fromName, toName string) error {
route.Addrs = []netroute.Addr{ s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, if err != nil {
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, return fmt.Errorf("failed to create socket: %w", err)
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
} }
defer syscall.Close(s)
fd := uintptr(s)
var fromNameBytes [unix.IFNAMSIZ]byte
var toNameBytes [unix.IFNAMSIZ]byte
copy(fromNameBytes[:], fromName)
copy(toNameBytes[:], toName)
ifrr := ifreqRename{
Name: fromNameBytes,
Data: uintptr(unsafe.Pointer(&toNameBytes)),
} }
data, err := route.Marshal() // Set the device name using SIOCSIFNAME ioctl
if err != nil { _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
return fmt.Errorf("failed to create route.RouteMessage: %w", err) if errno != 0 {
} return fmt.Errorf("SIOCSIFNAME ioctl failed: %w", errno)
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
} }
return nil return nil
} }
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}

View File

@@ -151,10 +151,6 @@ func (t *tun) Name() string {
return "iOS" return "iOS"
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
} }

View File

@@ -1,5 +1,5 @@
//go:build !android && !e2e_testing //go:build linux && !android && !e2e_testing
// +build !android,!e2e_testing // +build linux,!android,!e2e_testing
package overlay package overlay
@@ -9,133 +9,105 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strings"
"sync/atomic"
"time" "time"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
) )
type tun struct { type tun struct {
io.ReadWriteCloser
fd int
Device string
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int deviceIndex int
ioctlFd uintptr ioctlFd uintptr
txQueueLen int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeChan chan struct{}
useSystemRoutes bool useSystemRoutes bool
useSystemRoutesBufferSize int useSystemRoutesBufferSize int
l *logrus.Logger
} }
func (t *tun) Networks() []netip.Prefix { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*wgTun, error) {
return t.vpnNetworks deviceName := c.GetString("tun.dev", "")
} mtu := c.GetInt("tun.mtu", DefaultMTU)
type ifReq struct { // Create WireGuard TUN device
Name [16]byte tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
Flags uint16
pad [8]byte
}
type ifreqMTU struct {
Name [16]byte
MTU int32
pad [8]byte
}
type ifreqQLEN struct {
Name [16]byte
Value int32
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create TUN device: %w", err)
} }
t.Device = "tun0" // Get the actual device name
actualName, err := tunDevice.Name()
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker) tunDevice.Close()
if os.IsNotExist(err) { return nil, fmt.Errorf("failed to get TUN device name: %w", err)
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
if err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
} }
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) t := &wgTun{
if err != nil { tunDevice: tunDevice,
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err
}
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
t.Device = name
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), MaxMTU: mtu,
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), DefaultMTU: mtu,
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l, l: l,
} }
err := t.reload(c, true) // Create Linux-specific route manager
routeManager := &tun{
txQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
}
t.routeManager = routeManager
err = t.reload(c, true)
if err != nil { if err != nil {
tunDevice.Close()
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*wgTun, error) {
// Create TUN device from file descriptor
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
mtu := c.GetInt("tun.mtu", DefaultMTU)
tunDevice, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device from fd: %w", err)
}
t := &wgTun{
tunDevice: tunDevice,
vpnNetworks: vpnNetworks,
MaxMTU: mtu,
DefaultMTU: mtu,
l: l,
}
// Create Linux-specific route manager
routeManager := &tun{
txQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
}
t.routeManager = routeManager
err = t.reload(c, true)
if err != nil {
tunDevice.Close()
return nil, err return nil, err
} }
@@ -149,277 +121,105 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
return t, nil return t, nil
} }
func (t *tun) reload(c *config.C, initial bool) error { func (rm *tun) Activate(t *wgTun) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) name, err := t.tunDevice.Name()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get device name: %w", err)
} }
if !initial && !routeChange && !c.HasChanged("tun.mtu") { if t.routeManager.useSystemRoutes {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, true)
if err != nil {
return err
}
oldDefaultMTU := t.DefaultMTU
oldMaxMTU := t.MaxMTU
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
newMaxMTU := newDefaultMTU
for i, r := range routes {
if r.MTU == 0 {
routes[i].MTU = newDefaultMTU
}
if r.MTU > t.MaxMTU {
newMaxMTU = r.MTU
}
}
t.MaxMTU = newMaxMTU
t.DefaultMTU = newDefaultMTU
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
if oldMaxMTU != newMaxMTU {
t.setMTU()
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// This should never be called since addRoutes should log its own errors in a reload condition
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
}
}
return nil
}
func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) Write(b []byte) (int, error) {
var nn int
maximum := len(b)
for {
n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 {
nn += n
}
if nn == len(b) {
return nn, err
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, io.ErrUnexpectedEOF
}
}
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
//add all new addresses
for i := range newAddrs {
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
//iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
func (t *tun) Activate() error {
devName := t.deviceBytes()
if t.useSystemRoutes {
t.watchRoutes() t.watchRoutes()
} }
// Get the netlink device
link, err := netlink.LinkByName(name)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
rm.deviceIndex = link.Attrs().Index
// Open socket for ioctl operations
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
unix.IPPROTO_IP, unix.IPPROTO_IP,
) )
if err != nil { if err != nil {
return err return err
} }
t.ioctlFd = uintptr(s) rm.ioctlFd = uintptr(s)
// Set the device name rm.SetMTU(t, t.MaxMTU)
ifrf := ifReq{Name: devName}
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
// Setup our default MTU
t.setMTU()
// Set the transmit queue length // Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} devName := deviceBytes(name)
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { ifrq := ifreqQLEN{Name: devName, Value: int32(rm.txQueueLen)}
if err = ioctl(t.routeManager.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss // If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length") t.l.WithError(err).Error("Failed to set tun tx queue length")
} }
// Disable IPv6 link-local address generation
const modeNone = 1 const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil { if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation") t.l.WithError(err).Warn("Failed to disable link local address generation")
} }
if err = t.addIPs(link); err != nil { // Add IP addresses
if err = t.routeManager.addIPs(t, link); err != nil {
return err return err
} }
// Bring up the interface // Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP if err = netlink.LinkSetUp(link); err != nil {
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err) return fmt.Errorf("failed to bring the tun device up: %s", err)
} }
//set route MTU // Set route MTU
for i := range t.vpnNetworks { for i := range t.vpnNetworks {
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { if err = t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i]); err != nil {
return fmt.Errorf("failed to set default route MTU: %w", err) return fmt.Errorf("failed to set default route MTU: %w", err)
} }
} }
// Set the routes // Set the routes
if err = t.addRoutes(false); err != nil { if err = t.routeManager.AddRoutes(t, false); err != nil {
return err return err
} }
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
return nil return nil
} }
func (t *tun) setMTU() { func (rm *tun) SetMTU(t *wgTun, mtu int) {
// Set the MTU on the device name, err := t.tunDevice.Name()
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} if err != nil {
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { t.l.WithError(err).Error("Failed to get device name for MTU set")
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well return
}
link, err := netlink.LinkByName(name)
if err != nil {
t.l.WithError(err).Error("Failed to get link for MTU set")
return
}
if err := netlink.LinkSetMTU(link, mtu); err != nil {
t.l.WithError(err).Error("Failed to set tun mtu") t.l.WithError(err).Error("Failed to set tun mtu")
} }
} }
func (t *tun) setDefaultRoute(cidr netip.Prefix) error { func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
dr := &net.IPNet{ dr := &net.IPNet{
IP: cidr.Masked().Addr().AsSlice(), IP: cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()), Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
} }
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: t.deviceIndex, LinkIndex: t.routeManager.deviceIndex,
Dst: dr, Dst: dr,
MTU: t.DefaultMTU, MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}), AdvMSS: advMSS(Route{}, t.DefaultMTU, t.MaxMTU),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Src: net.IP(cidr.Addr().AsSlice()), Src: net.IP(cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
@@ -429,7 +229,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
err := netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` // Retry twice more
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
err = netlink.RouteReplace(&nr) err = netlink.RouteReplace(&nr)
@@ -447,8 +247,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
return nil return nil
} }
func (t *tun) addRoutes(logErrors bool) error { func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
// Path routes
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {
@@ -461,10 +260,10 @@ func (t *tun) addRoutes(logErrors bool) error {
} }
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: t.deviceIndex, LinkIndex: t.routeManager.deviceIndex,
Dst: dr, Dst: dr,
MTU: r.MTU, MTU: r.MTU,
AdvMSS: t.advMSS(r), AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
} }
@@ -488,7 +287,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return nil return nil
} }
func (t *tun) removeRoutes(routes []Route) { func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {
continue continue
@@ -500,10 +299,10 @@ func (t *tun) removeRoutes(routes []Route) {
} }
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: t.deviceIndex, LinkIndex: t.routeManager.deviceIndex,
Dst: dr, Dst: dr,
MTU: r.MTU, MTU: r.MTU,
AdvMSS: t.advMSS(r), AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
} }
@@ -520,30 +319,105 @@ func (t *tun) removeRoutes(routes []Route) {
} }
} }
func (t *tun) Name() string { func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
return t.Device // For Linux with WireGuard TUN, we can reuse the same device
// The vectorized I/O will handle batching
return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
} }
func (t *tun) advMSS(r Route) int { func deviceBytes(name string) [16]byte {
var o [16]byte
for i, c := range name {
if i >= 16 {
break
}
o[i] = byte(c)
}
return o
}
func advMSS(r Route, defaultMTU, maxMTU int) int {
mtu := r.MTU mtu := r.MTU
if r.MTU == 0 { if r.MTU == 0 {
mtu = t.DefaultMTU mtu = defaultMTU
} }
// We only need to set advmss if the route MTU does not match the device MTU // We only need to set advmss if the route MTU does not match the device MTU
if mtu != t.MaxMTU { if mtu != maxMTU {
return mtu - 40 return mtu - 40
} }
return 0 return 0
} }
func (t *tun) watchRoutes() { type ifreqQLEN struct {
Name [16]byte
Value int32
pad [8]byte
}
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
func (rm *tun) addIPs(t *wgTun, link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
// Add all new addresses
for i := range newAddrs {
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
// Iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
// watchRoutes monitors system route changes
func (t *wgTun) watchRoutes() {
rch := make(chan netlink.RouteUpdate) rch := make(chan netlink.RouteUpdate)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
netlinkOptions := netlink.RouteSubscribeOptions{ netlinkOptions := netlink.RouteSubscribeOptions{
ReceiveBufferSize: t.useSystemRoutesBufferSize, ReceiveBufferSize: t.routeManager.useSystemRoutesBufferSize,
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0, ReceiveBufferForceSize: t.routeManager.useSystemRoutesBufferSize != 0,
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") }, ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
} }
@@ -561,98 +435,19 @@ func (t *tun) watchRoutes() {
if ok { if ok {
t.updateRoutes(r) t.updateRoutes(r)
} else { } else {
// may be should do something here as
// netlink stops sending updates
return return
} }
case <-doneChan: case <-doneChan:
// netlink.RouteSubscriber will close the rch for us
return return
} }
} }
}() }()
} }
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool { func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
withinNetworks := false gateways := t.getGatewaysFromRoute(&r.Route, t.routeManager.deviceIndex)
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
withinNetworks = true
break
}
}
return withinNetworks
}
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
}
}
}
routing.CalculateBucketsForGateways(gateways)
return gateways
}
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
// Try to use the old RTA_GATEWAY first
gwAddr, ok := netip.AddrFromSlice(gw)
if !ok {
// Fallback to the new RTA_VIA
rVia, ok := via.(*netlink.Via)
if ok {
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
}
}
if gwAddr.IsValid() {
gwAddr = gwAddr.Unmap()
return gwAddr, true
}
return netip.Addr{}, false
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 { if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways") t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
return return
} }
@@ -676,7 +471,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
if r.Type == unix.RTM_NEWROUTE { if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
newTree.Insert(dst, gateways) newTree.Insert(dst, gateways)
} else { } else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
newTree.Delete(dst) newTree.Delete(dst)
@@ -684,18 +478,71 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
t.routeTree.Store(newTree) t.routeTree.Store(newTree)
} }
func (t *tun) Close() error { func (t *wgTun) getGatewaysFromRoute(r *netlink.Route, deviceIndex int) routing.Gateways {
if t.routeChan != nil { var gateways routing.Gateways
close(t.routeChan)
name, err := t.tunDevice.Name()
if err != nil {
t.l.Error("Ignoring route update: failed to get device name")
return gateways
} }
if t.ReadWriteCloser != nil { link, err := netlink.LinkByName(name)
_ = t.ReadWriteCloser.Close() if err != nil {
t.l.WithField("DeviceName", name).Error("Ignoring route update: failed to get link by name")
return gateways
} }
if t.ioctlFd > 0 { // If this route is relevant to our interface and there is a gateway then add it
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close() if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
}
} }
for _, p := range r.MultiPath {
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
}
}
}
routing.CalculateBucketsForGateways(gateways)
return gateways
}
func (t *wgTun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
return true
}
}
return false
}
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil return nil
} }

View File

@@ -7,25 +7,26 @@ import "testing"
var runAdvMSSTests = []struct { var runAdvMSSTests = []struct {
name string name string
tun *tun defaultMTU int
maxMTU int
r Route r Route
expected int expected int
}{ }{
// Standard case, default MTU is the device max MTU // Standard case, default MTU is the device max MTU
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, {"default", 1440, 1440, Route{}, 0},
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, {"default-min", 1440, 1440, Route{MTU: 1440}, 0},
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, {"default-low", 1440, 1440, Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default // Case where we have a route MTU set higher than the default
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, {"route", 1440, 8941, Route{}, 1400},
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, {"route-min", 1440, 8941, Route{MTU: 1440}, 1400},
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, {"route-high", 1440, 8941, Route{MTU: 8941}, 0},
} }
func TestTunAdvMSS(t *testing.T) { func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests { for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r) o := advMSS(tt.r, tt.defaultMTU, tt.maxMTU)
if o != tt.expected { if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected) t.Errorf("got %d, want %d", o, tt.expected)
} }

View File

@@ -390,10 +390,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
} }
@@ -551,3 +547,41 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
return nil return nil
} }
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}
func prefixToMask(prefix netip.Prefix) netip.Addr {
bits := prefix.Bits()
if prefix.Addr().Is4() {
mask := ^uint32(0) << (32 - bits)
return netip.AddrFrom4([4]byte{
byte(mask >> 24),
byte(mask >> 16),
byte(mask >> 8),
byte(mask),
})
}
var mask [16]byte
for i := 0; i < bits/8; i++ {
mask[i] = 0xff
}
if bits%8 != 0 {
mask[bits/8] = ^byte(0) << (8 - bits%8)
}
return netip.AddrFrom16(mask)
}
func selectGateway(prefix netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
for _, gw := range gateways {
if prefix.Addr().Is4() == gw.Addr().Is4() {
return gw, nil
}
}
return netip.Prefix{}, fmt.Errorf("no suitable gateway found for prefix %v", prefix)
}

View File

@@ -1,14 +0,0 @@
//go:build !windows
// +build !windows
package overlay
import "syscall"
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}

View File

@@ -1,104 +1,59 @@
//go:build !e2e_testing //go:build openbsd && !e2e_testing
// +build !e2e_testing // +build openbsd,!e2e_testing
package overlay package overlay
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
"os" "os/exec"
"regexp" "strconv"
"sync/atomic" "strings"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" wgtun "golang.zx2c4.com/wireguard/tun"
"golang.org/x/sys/unix"
) )
const ( type tun struct{}
SIOCAIFADDR_IN6 = 0x8080691a
)
type ifreqAlias4 struct { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
Name [unix.IFNAMSIZ]byte return nil, fmt.Errorf("newTunFromFd not supported on OpenBSD")
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
} }
type ifreqAlias6 struct { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
Name [unix.IFNAMSIZ]byte deviceName := c.GetString("tun.dev", "tun")
Addr unix.RawSockaddrInet6 mtu := c.GetInt("tun.mtu", DefaultMTU)
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime [2]uint32
}
type ifreq struct { // Create WireGuard TUN device
Name [unix.IFNAMSIZ]byte tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
data int
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create TUN device: %w", err)
} }
err = unix.SetNonblock(fd, true) // Get the actual device name
actualName, err := tunDevice.Name()
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking") tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
} }
t := &tun{ t := &wgTun{
f: os.NewFile(uintptr(fd), ""), tunDevice: tunDevice,
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MaxMTU: mtu,
DefaultMTU: mtu,
l: l, l: l,
} }
// Create OpenBSD-specific route manager
t.routeManager = &tun{}
err = t.reload(c, true) err = t.reload(c, true)
if err != nil { if err != nil {
tunDevice.Close()
return nil, err return nil, err
} }
@@ -109,225 +64,86 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
}) })
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil return t, nil
} }
func (t *tun) Close() error { func (rm *tun) Activate(t *wgTun) error {
if t.f != nil { name, err := t.tunDevice.Name()
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.f.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.f.Write(buf)
return n - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil { if err != nil {
return fmt.Errorf("failed to get device name: %w", err)
}
// Set the MTU
rm.SetMTU(t, t.MaxMTU)
// Add IP addresses
for _, network := range t.vpnNetworks {
if err := rm.addIP(t, name, network); err != nil {
return err return err
} }
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
} }
err = addRoute(cidr, t.vpnNetworks) // Bring up the interface
if err != nil { if err := runCommandBSD("ifconfig", name, "up"); err != nil {
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err) return fmt.Errorf("failed to bring up interface: %w", err)
}
// Set the routes
if err := rm.AddRoutes(t, false); err != nil {
return err
} }
return nil return nil
} }
if cidr.Addr().Is6() { func (rm *tun) SetMTU(t *wgTun, mtu int) {
var req ifreqAlias6 name, err := t.tunDevice.Name()
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime[0] = 0xffffffff
req.Lifetime[1] = 0xffffffff
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil { if err != nil {
return err t.l.WithError(err).Error("Failed to get device name for MTU set")
} return
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
} }
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
t.l.WithError(err).Error("Failed to set tun mtu")
}
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
// On OpenBSD, routes are set via ifconfig and route commands
return nil return nil
} }
return fmt.Errorf("unknown address type %v", cidr) func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
} name, err := t.tunDevice.Name()
func (t *tun) Activate() error {
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil { if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err) return fmt.Errorf("failed to get device name: %w", err)
} }
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if len(r.Via) == 0 || !r.Install { if !r.Install {
// We don't allow route MTUs so only install routes with a via
continue continue
} }
err := addRoute(r.Cidr, t.vpnNetworks) // Add route using route command
args := []string{"add"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
if r.Metric > 0 {
// OpenBSD doesn't support route metrics directly like Linux
t.l.WithField("route", r).Warn("Route metrics are not fully supported on OpenBSD")
}
err := runCommandBSD("route", args...)
if err != nil { if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
@@ -343,131 +159,71 @@ func (t *tun) addRoutes(logErrors bool) error {
return nil return nil
} }
func (t *tun) removeRoutes(routes []Route) error { func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get device name for route removal")
return
}
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {
continue continue
} }
err := delRoute(r.Cidr, t.vpnNetworks) args := []string{"delete"}
if r.Cidr.Addr().Is6() {
args = append(args, "-inet6")
} else {
args = append(args, "-inet")
}
args = append(args, r.Cidr.String(), "-interface", name)
err := runCommandBSD("route", args...)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.WithField("route", r).Info("Removed route")
} }
} }
return nil
} }
func (t *tun) deviceBytes() (o [16]byte) { func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
for i, c := range t.Device { // OpenBSD doesn't support multi-queue TUN devices in the same way as Linux
o[i] = byte(c) // Return a reader that wraps the same device
} return &wgTunReader{
return parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
} }
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error { func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) addr := network.Addr()
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{ if addr.Is4() {
Version: unix.RTM_VERSION, // For IPv4: ifconfig tun0 10.0.0.1/24
Type: unix.RTM_ADD, if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
Flags: unix.RTF_UP | unix.RTF_GATEWAY, return fmt.Errorf("failed to add IPv4 address: %w", err)
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
} }
} else { } else {
gw, err := selectGateway(prefix, gateways) // For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
if err != nil { if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
return err return fmt.Errorf("failed to add IPv6 address: %w", err)
} }
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
} }
return nil return nil
} }
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error { func runCommandBSD(name string, args ...string) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) cmd := exec.Command(name, args...)
output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
} }
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil return nil
} }

View File

@@ -132,10 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented") return nil, fmt.Errorf("TODO: multiqueue not implemented")
} }

242
overlay/tun_wg.go Normal file
View File

@@ -0,0 +1,242 @@
//go:build !android && !netbsd && !e2e_testing
// +build !android,!netbsd,!e2e_testing
package overlay
import (
"fmt"
"io"
"net/netip"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
wgtun "golang.zx2c4.com/wireguard/tun"
)
// wgTun wraps a WireGuard TUN device and implements the overlay.Device interface
type wgTun struct {
tunDevice wgtun.Device
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeChan chan struct{}
// Platform-specific route management
routeManager *tun
l *logrus.Logger
}
// BatchReader interface for readers that support vectorized I/O
type BatchReader interface {
BatchRead(buffers [][]byte, sizes []int) (int, error)
}
// BatchWriter interface for writers that support vectorized I/O
type BatchWriter interface {
BatchWrite(packets [][]byte) (int, error)
}
// wgTunReader wraps a single TUN queue for multi-queue support
type wgTunReader struct {
parent *wgTun
tunDevice wgtun.Device
offset int
l *logrus.Logger
}
func (t *wgTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *wgTun) Name() string {
name, err := t.tunDevice.Name()
if err != nil {
t.l.WithError(err).Error("Failed to get TUN device name")
return "unknown"
}
return name
}
func (t *wgTun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *wgTun) Activate() error {
if t.routeManager == nil {
return fmt.Errorf("route manager not initialized")
}
return t.routeManager.Activate(t)
}
// Read implements single-packet read for backward compatibility
func (t *wgTun) Read(b []byte) (int, error) {
bufs := [][]byte{b}
sizes := []int{0}
n, err := t.tunDevice.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrNoProgress
}
return sizes[0], nil
}
// Write implements single-packet write for backward compatibility
func (t *wgTun) Write(b []byte) (int, error) {
bufs := [][]byte{b}
offset := 0
// WireGuard TUN expects the packet data to start at offset 0
n, err := t.tunDevice.Write(bufs, offset)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
func (t *wgTun) Close() error {
if t.routeChan != nil {
close(t.routeChan)
}
if t.tunDevice != nil {
return t.tunDevice.Close()
}
return nil
}
func (t *wgTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// For WireGuard TUN, we need to create separate TUN device instances for multi-queue
// The platform-specific implementation will handle this
if t.routeManager == nil {
return nil, fmt.Errorf("route manager not initialized for multi-queue reader")
}
return t.routeManager.NewMultiQueueReader(t)
}
func (t *wgTun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, true)
if err != nil {
return err
}
oldDefaultMTU := t.DefaultMTU
oldMaxMTU := t.MaxMTU
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
newMaxMTU := newDefaultMTU
for i, r := range routes {
if r.MTU == 0 {
routes[i].MTU = newDefaultMTU
}
if r.MTU > t.MaxMTU {
newMaxMTU = r.MTU
}
}
t.MaxMTU = newMaxMTU
t.DefaultMTU = newDefaultMTU
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial && t.routeManager != nil {
if oldMaxMTU != newMaxMTU {
t.routeManager.SetMTU(t, t.MaxMTU)
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
for i := range t.vpnNetworks {
err := t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.routeManager.RemoveRoutes(t, findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.routeManager.AddRoutes(t, true)
if err != nil {
// This should never be called since AddRoutes should log its own errors in a reload condition
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
}
}
return nil
}
// BatchRead reads multiple packets from the TUN device using vectorized I/O
// The caller provides buffers and sizes slices, and this function returns the number of packets read.
func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
return r.tunDevice.Read(buffers, sizes, r.offset)
}
// Read implements io.Reader for wgTunReader (single packet for compatibility)
func (r *wgTunReader) Read(b []byte) (int, error) {
bufs := [][]byte{b}
sizes := []int{0}
n, err := r.tunDevice.Read(bufs, sizes, r.offset)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrNoProgress
}
return sizes[0], nil
}
// Write implements io.Writer for wgTunReader
func (r *wgTunReader) Write(b []byte) (int, error) {
bufs := [][]byte{b}
n, err := r.tunDevice.Write(bufs, r.offset)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
// BatchWrite writes multiple packets to the TUN device using vectorized I/O
func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) {
return r.tunDevice.Write(packets, r.offset)
}
func (r *wgTunReader) Close() error {
if r.tunDevice != nil {
return r.tunDevice.Close()
}
return nil
}

View File

@@ -1,84 +1,77 @@
//go:build !e2e_testing //go:build windows && !e2e_testing
// +build !e2e_testing // +build windows,!e2e_testing
package overlay package overlay
import ( import (
"crypto" "crypto"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
wgtun "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
) )
const tunGUIDLabel = "Fixed Nebula Windows GUID v1" const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct { type tun struct {
Device string luid winipcfg.LUID
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
tun *wintun.NativeTun
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
err := checkWinTunExists() deviceName := c.GetString("tun.dev", "Nebula")
mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create WireGuard TUN device
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("can not load the wintun driver: %w", err) return nil, fmt.Errorf("failed to create TUN device: %w", err)
} }
deviceName := c.GetString("tun.dev", "") // Get the actual device name
guid, err := generateGUIDByDeviceName(deviceName) actualName, err := tunDevice.Name()
if err != nil { if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err) tunDevice.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
} }
t := &winTun{ t := &wgTun{
Device: deviceName, tunDevice: tunDevice,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MaxMTU: mtu,
DefaultMTU: mtu,
l: l, l: l,
} }
// Create Windows-specific route manager
rm := &tun{}
// Get LUID from the TUN device
// The WireGuard TUN device on Windows should provide a LUID() method
if nativeTun, ok := tunDevice.(interface{ LUID() uint64 }); ok {
rm.luid = winipcfg.LUID(nativeTun.LUID())
} else {
tunDevice.Close()
return nil, fmt.Errorf("failed to get LUID from TUN device")
}
t.routeManager = rm
err = t.reload(c, true) err = t.reload(c, true)
if err != nil { if err != nil {
tunDevice.Close()
return nil, err return nil, err
} }
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false) err := t.reload(c, false)
if err != nil { if err != nil {
@@ -86,210 +79,140 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
}) })
l.WithField("name", actualName).Info("Created WireGuard TUN device")
return t, nil return t, nil
} }
func (t *winTun) reload(c *config.C, initial bool) error { func (rm *tun) Activate(t *wgTun) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) // Set MTU
err := rm.setMTU(t, t.MaxMTU)
if err != nil { if err != nil {
return fmt.Errorf("failed to set MTU: %w", err)
}
// Add IP addresses
for _, network := range t.vpnNetworks {
if err := rm.addIP(t, network); err != nil {
return err return err
} }
if !initial && !change {
return nil
} }
routeTree, err := makeRouteTree(t.l, routes, false) // Add routes
if err != nil { if err := rm.AddRoutes(t, false); err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses(t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err return err
} }
return nil return nil
} }
func (t *winTun) addRoutes(logErrors bool) error { func (rm *tun) SetMTU(t *wgTun, mtu int) {
luid := winipcfg.LUID(t.tun.LUID()) if err := rm.setMTU(t, mtu); err != nil {
t.l.WithError(err).Error("Failed to set MTU")
}
}
func (rm *tun) setMTU(t *wgTun, mtu int) error {
// Set MTU using winipcfg
// Note: MTU setting on Windows TUN devices may be handled by the driver
// For now, we'll skip explicit MTU setting as the WireGuard TUN handles it
return nil
}
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
// On Windows, routes are managed differently
return nil
}
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes { for _, r := range routes {
if len(r.Via) == 0 || !r.Install { if !r.Install {
// We don't allow route MTUs so only install routes with a via
continue continue
} }
// Add our unsafe route if r.MTU > 0 {
// Windows does not support multipath routes natively, so we install only a single route. // Windows route MTU is not directly supported
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. t.l.WithField("route", r).Debug("Route MTU is not supported on Windows")
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. }
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
// Use winipcfg to add the route
// The rm.luid should have the AddRoute method from winipcfg
if len(r.Via) == 0 {
t.l.WithField("route", r).Warn("Route has no via address, skipping")
continue
}
err := rm.luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
if err != nil { if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
continue
} else { } else {
return retErr return retErr
} }
} else { } else {
t.l.WithField("route", r).Info("Added route") t.l.WithField("route", r).Info("Added route")
} }
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
} }
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil return nil
} }
func (t *winTun) removeRoutes(routes []Route) error { func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {
continue continue
} }
// See comment on luid.AddRoute if len(r.Via) == 0 {
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) continue
}
err := rm.luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.WithField("route", r).Info("Removed route")
} }
} }
}
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
// Windows doesn't support multi-queue TUN devices
// Return a reader that wraps the same device
return &wgTunReader{
parent: t,
tunDevice: t.tunDevice,
offset: 0,
l: t.l,
}, nil
}
func (rm *tun) addIP(t *wgTun, network netip.Prefix) error {
// Add IP address using winipcfg
// SetIPAddresses expects a slice of prefixes
err := rm.luid.SetIPAddresses([]netip.Prefix{network})
if err != nil {
return fmt.Errorf("failed to add IP address %s: %w", network, err)
}
return nil return nil
} }
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways { // generateGUIDByDeviceName generates a GUID based on the device name
r, _ := t.routeTree.Load().Lookup(ip) func generateGUIDByDeviceName(deviceName string) (*windows.GUID, error) {
return r // Hash the device name to create a deterministic GUID
h := crypto.SHA256.New()
h.Write([]byte(tunGUIDLabel))
h.Write([]byte(deviceName))
sum := h.Sum(nil)
guid := &windows.GUID{
Data1: binary.LittleEndian.Uint32(sum[0:4]),
Data2: binary.LittleEndian.Uint16(sum[4:6]),
Data3: binary.LittleEndian.Uint16(sum[6:8]),
} }
copy(guid.Data4[:], sum[8:16])
func (t *winTun) Networks() []netip.Prefix { return guid, nil
return t.vpnNetworks
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)
return t.tun.Close()
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func checkWinTunExists() error {
myPath, err := os.Executable()
if err != nil {
return err
}
arch := runtime.GOARCH
switch arch {
case "386":
//NOTE: wintun bundles 386 as x86
arch = "x86"
}
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
return err
} }

View File

@@ -46,10 +46,6 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)} return routing.Gateways{routing.NewGateway(ip, 1)}
} }
func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil return d, nil
} }

47
pki.go
View File

@@ -100,36 +100,41 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
currentState := p.cs.Load() currentState := p.cs.Load()
if newState.v1Cert != nil { if newState.v1Cert != nil {
if currentState.v1Cert == nil { if currentState.v1Cert == nil {
//adding certs is fine, actually. Networks-in-common confirmed in newCertState(). return util.NewContextualError("v1 certificate was added, restart required", nil, err)
} else { }
// did IP in cert change? if so, don't set // did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError( return util.NewContextualError(
"Networks in new cert was different from old", "Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1}, m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil, nil,
) )
} }
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError( return util.NewContextualError(
"Curve in new v1 cert was different from old", "Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1}, m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil, nil,
) )
} }
}
} else if currentState.v1Cert != nil {
//TODO: CERT-V2 we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
} }
if newState.v2Cert != nil { if newState.v2Cert != nil {
if currentState.v2Cert == nil { if currentState.v2Cert == nil {
//adding certs is fine, actually return util.NewContextualError("v2 certificate was added, restart required", nil, err)
} else { }
// did IP in cert change? if so, don't set // did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError( return util.NewContextualError(
"Networks in new cert was different from old", "Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2}, m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil, nil,
) )
} }
@@ -137,25 +142,13 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError( return util.NewContextualError(
"Curve in new cert was different from old", "Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2}, m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil, nil,
) )
} }
}
} else if currentState.v2Cert != nil { } else if currentState.v2Cert != nil {
//newState.v1Cert is non-nil bc empty certstates aren't permitted return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
if newState.v1Cert == nil {
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
}
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
nil,
)
}
} }
// Cipher cant be hot swapped so just leave it at what it was before // Cipher cant be hot swapped so just leave it at what it was before
@@ -523,14 +516,10 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
} }
bl := c.GetStringSlice("pki.blocklist", []string{}) for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
if len(bl) > 0 { l.WithField("fingerprint", fp).Info("Blocklisting cert")
for _, fp := range bl {
caPool.BlocklistFingerprint(fp) caPool.BlocklistFingerprint(fp)
} }
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
}
return caPool, nil return caPool, nil
} }

View File

@@ -16,8 +16,8 @@ import (
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any

View File

@@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil return 0, nil
} }
func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported") return nil, errors.New("unsupported")
} }

View File

@@ -19,7 +19,6 @@ type Conn interface {
ListenOut(r EncReader) ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
Close() error Close() error
} }
@@ -34,9 +33,6 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) ListenOut(_ EncReader) {
return return
} }
func (NoopConn) SupportsMultipleReaders() bool {
return false
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }

View File

@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
return ErrInvalidIPv6RemoteForSocket return ErrInvalidIPv6RemoteForSocket
} }
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As4() rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa) sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4 addrLen = syscall.SizeofSockaddrInet4
@@ -184,10 +184,6 @@ func (u *StdConn) ListenOut(r EncReader) {
} }
} }
func (u *StdConn) SupportsMultipleReaders() bool {
return false
}
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
var err error var err error
if u.isV4 { if u.isV4 {

View File

@@ -85,7 +85,3 @@ func (u *GenericConn) ListenOut(r EncReader) {
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
} }
} }
func (u *GenericConn) SupportsMultipleReaders() bool {
return false
}

View File

@@ -72,10 +72,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
} }
func (u *StdConn) SupportsMultipleReaders() bool {
return true
}
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
return nil return nil
} }

View File

@@ -315,10 +315,6 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
} }
func (u *RIOConn) SupportsMultipleReaders() bool {
return false
}
func (u *RIOConn) Rebind() error { func (u *RIOConn) Rebind() error {
return nil return nil
} }

View File

@@ -127,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil return u.Addr, nil
} }
func (u *TesterConn) SupportsMultipleReaders() bool {
return false
}
func (u *TesterConn) Rebind() error { func (u *TesterConn) Rebind() error {
return nil return nil
} }