mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 00:15:37 +01:00
Compare commits
18 Commits
jay.wren-w
...
changelog-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee8e4d2017 | ||
|
|
8d656fb890 | ||
|
|
27ea667aee | ||
|
|
4df8bcb1f5 | ||
|
|
36c890eaad | ||
|
|
44001244f2 | ||
|
|
a89f95182c | ||
|
|
6a8a2992ff | ||
|
|
3d94dfe6a1 | ||
|
|
3670e24fa0 | ||
|
|
b348ee726e | ||
|
|
a941b65114 | ||
|
|
17101d425f | ||
|
|
52f1908126 | ||
|
|
48f1ae98ba | ||
|
|
97b3972c11 | ||
|
|
0f305d5397 | ||
|
|
01909f4715 |
2
.github/workflows/gofmt.yml
vendored
2
.github/workflows/gofmt.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
|
||||
20
.github/workflows/release.yml
vendored
20
.github/workflows/release.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
name: Build Linux/BSD All
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
mv build/*.tar.gz release
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: linux-latest
|
||||
path: release
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
name: Build Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
mv dist\windows\wintun build\dist\windows\
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: windows-latest
|
||||
path: build
|
||||
@@ -66,7 +66,7 @@ jobs:
|
||||
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
@@ -104,7 +104,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: darwin-latest
|
||||
path: ./release/*
|
||||
@@ -124,11 +124,11 @@ jobs:
|
||||
# be overwritten
|
||||
- name: Checkout code
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Download artifacts
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: linux-latest
|
||||
path: artifacts
|
||||
@@ -160,10 +160,10 @@ jobs:
|
||||
needs: [build-linux, build-darwin, build-windows]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
|
||||
2
.github/workflows/smoke-extra.yml
vendored
2
.github/workflows/smoke-extra.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
|
||||
2
.github/workflows/smoke.yml
vendored
2
.github/workflows/smoke.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
|
||||
16
.github/workflows/test.yml
vendored
16
.github/workflows/test.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
@@ -32,7 +32,7 @@ jobs:
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.5
|
||||
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Build test mobile
|
||||
run: make build-test-mobile
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: e2e packet flow linux-latest
|
||||
path: e2e/mermaid/linux-latest
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
@@ -77,7 +77,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
os: [windows-latest, macos-latest]
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
@@ -115,7 +115,7 @@ jobs:
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.5
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
- name: End 2 end
|
||||
run: make e2evv
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: e2e packet flow ${{ matrix.os }}
|
||||
path: e2e/mermaid/${{ matrix.os }}
|
||||
|
||||
60
CHANGELOG.md
60
CHANGELOG.md
@@ -7,12 +7,64 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [1.10.0] - ????
|
||||
|
||||
### Added
|
||||
|
||||
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153)
|
||||
- ASN.1 based v2 nebula certificates with support for ipv6 and multiple ip addresses.
|
||||
Certificates now have a unified interface for external implementations. (#1212, #1216, #1345)
|
||||
**TODO: External documentation link!**
|
||||
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
|
||||
- Add ECMP support for `unsafe_routes`. (#1332)
|
||||
|
||||
### Changed
|
||||
|
||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||
deprecated and will be removed in a future release.
|
||||
deprecated and will be removed in a future release. (#1373)
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix moving a udp address from one vpn address to another in the `static_host_map`
|
||||
which could cause rapid re-handshaking with an incorrect remote. (#1259)
|
||||
- Improve smoke tests in environments where the docker network is not the default. (#1347)
|
||||
|
||||
## [1.9.7] - 2025-10-10
|
||||
|
||||
### Security
|
||||
|
||||
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
|
||||
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
|
||||
|
||||
### Changed
|
||||
|
||||
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
|
||||
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1543)
|
||||
|
||||
## [1.9.6] - 2025-7-15
|
||||
|
||||
### Added
|
||||
|
||||
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
|
||||
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
|
||||
- Fix Windows freeze due to ICMP error handling (#1412)
|
||||
- Fix relay migration panic (#1403)
|
||||
|
||||
## [1.9.5] - 2024-12-05
|
||||
|
||||
### Added
|
||||
|
||||
- Gracefully ignore v2 certificates. (#1282)
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
|
||||
|
||||
## [1.9.4] - 2024-09-09
|
||||
|
||||
@@ -671,7 +723,11 @@ created.)
|
||||
|
||||
- Initial public release.
|
||||
|
||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
|
||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
|
||||
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
|
||||
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
|
||||
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
|
||||
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
|
||||
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
|
||||
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
|
||||
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
|
||||
|
||||
109
bits.go
109
bits.go
@@ -9,14 +9,13 @@ type Bits struct {
|
||||
length uint64
|
||||
current uint64
|
||||
bits []bool
|
||||
firstSeen bool
|
||||
lostCounter metrics.Counter
|
||||
dupeCounter metrics.Counter
|
||||
outOfWindowCounter metrics.Counter
|
||||
}
|
||||
|
||||
func NewBits(bits uint64) *Bits {
|
||||
return &Bits{
|
||||
b := &Bits{
|
||||
length: bits,
|
||||
bits: make([]bool, bits, bits),
|
||||
current: 0,
|
||||
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
|
||||
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
||||
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
|
||||
}
|
||||
|
||||
// There is no counter value 0, mark it to avoid counting a lost packet later.
|
||||
b.bits[0] = true
|
||||
b.current = 0
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
||||
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
||||
// If i is the next number, return true.
|
||||
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
|
||||
if i > b.current {
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is within the window, check if it's been set already. The first window will fail this check
|
||||
if i > b.current-b.length {
|
||||
return !b.bits[i%b.length]
|
||||
}
|
||||
|
||||
// If i is within the first window
|
||||
if i < b.length {
|
||||
// If i is within the window, check if it's been set already.
|
||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||
return !b.bits[i%b.length]
|
||||
}
|
||||
|
||||
// Not within the window
|
||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
// If i is the next number, return true and update current.
|
||||
if i == b.current+1 {
|
||||
// Report missed packets, we can only understand what was missed after the first window has been gone through
|
||||
if i > b.length && b.bits[i%b.length] == false {
|
||||
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
||||
// The very first window can only be tracked as lost once we are on the 2nd window or greater
|
||||
if b.bits[i%b.length] == false && i > b.length {
|
||||
b.lostCounter.Inc(1)
|
||||
}
|
||||
b.bits[i%b.length] = true
|
||||
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// If i packet is greater than current but less than the maximum length of our bitmap,
|
||||
// flip everything in between to false and move ahead.
|
||||
if i > b.current && i < b.current+b.length {
|
||||
// In between current and i need to be zero'd to allow those packets to come in later
|
||||
for n := b.current + 1; n < i; n++ {
|
||||
// If i is a jump, adjust the window, record lost, update current, and return true
|
||||
if i > b.current {
|
||||
lost := int64(0)
|
||||
// Zero out the bits between the current and the new counter value, limited by the window size,
|
||||
// since the window is shifting
|
||||
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
|
||||
if b.bits[n%b.length] == false && n > b.length {
|
||||
lost++
|
||||
}
|
||||
b.bits[n%b.length] = false
|
||||
}
|
||||
|
||||
b.bits[i%b.length] = true
|
||||
b.current = i
|
||||
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is greater than the delta between current and the total length of our bitmap,
|
||||
// just flip everything in the map and move ahead.
|
||||
if i >= b.current+b.length {
|
||||
// The current window loss will be accounted for later, only record the jump as loss up until then
|
||||
lost := maxInt64(0, int64(i-b.current-b.length))
|
||||
//TODO: explain this
|
||||
if b.current == 0 {
|
||||
lost++
|
||||
}
|
||||
|
||||
for n := range b.bits {
|
||||
// Don't want to count the first window as a loss
|
||||
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
|
||||
//if b.bits[n] == false {
|
||||
// lost++
|
||||
//}
|
||||
b.bits[n] = false
|
||||
}
|
||||
|
||||
// Only record any skipped packets as a result of the window moving further than the window length
|
||||
// Any loss within the new window will be accounted for in future calls
|
||||
lost += max(0, int64(i-b.current-b.length))
|
||||
b.lostCounter.Inc(lost)
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
|
||||
Debug("Receive window")
|
||||
}
|
||||
b.bits[i%b.length] = true
|
||||
b.current = i
|
||||
return true
|
||||
}
|
||||
|
||||
// Allow for the 0 packet to come in within the first window
|
||||
if i == 0 && b.firstSeen == false && b.current < b.length {
|
||||
b.firstSeen = true
|
||||
b.bits[i%b.length] = true
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is within the window of current minus length (the total pat window size),
|
||||
// allow it and flip to true but to NOT change current. We also have to account for the first window
|
||||
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
|
||||
if b.current == i {
|
||||
// If i is within the current window but below the current counter,
|
||||
// Check to see if it's a duplicate
|
||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||
if b.current == i || b.bits[i%b.length] == true {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
||||
Debug("Receive window")
|
||||
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if b.bits[i%b.length] == true {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
|
||||
Debug("Receive window")
|
||||
}
|
||||
b.dupeCounter.Inc(1)
|
||||
return false
|
||||
}
|
||||
|
||||
b.bits[i%b.length] = true
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// In all other cases, fail and don't change current.
|
||||
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func maxInt64(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
109
bits_test.go
109
bits_test.go
@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
|
||||
assert.Len(t, b.bits, 10)
|
||||
|
||||
// This is initialized to zero - receive one. This should work.
|
||||
|
||||
assert.True(t, b.Check(l, 1))
|
||||
u := b.Update(l, 1)
|
||||
assert.True(t, u)
|
||||
assert.True(t, b.Update(l, 1))
|
||||
assert.EqualValues(t, 1, b.current)
|
||||
g := []bool{false, true, false, false, false, false, false, false, false, false}
|
||||
g := []bool{true, true, false, false, false, false, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Receive two
|
||||
assert.True(t, b.Check(l, 2))
|
||||
u = b.Update(l, 2)
|
||||
assert.True(t, u)
|
||||
assert.True(t, b.Update(l, 2))
|
||||
assert.EqualValues(t, 2, b.current)
|
||||
g = []bool{false, true, true, false, false, false, false, false, false, false}
|
||||
g = []bool{true, true, true, false, false, false, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Receive two again - it will fail
|
||||
assert.False(t, b.Check(l, 2))
|
||||
u = b.Update(l, 2)
|
||||
assert.False(t, u)
|
||||
assert.False(t, b.Update(l, 2))
|
||||
assert.EqualValues(t, 2, b.current)
|
||||
|
||||
// Jump ahead to 15, which should clear everything and set the 6th element
|
||||
assert.True(t, b.Check(l, 15))
|
||||
u = b.Update(l, 15)
|
||||
assert.True(t, u)
|
||||
assert.True(t, b.Update(l, 15))
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Mark 14, which is allowed because it is in the window
|
||||
assert.True(t, b.Check(l, 14))
|
||||
u = b.Update(l, 14)
|
||||
assert.True(t, u)
|
||||
assert.True(t, b.Update(l, 14))
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Mark 5, which is not allowed because it is not in the window
|
||||
assert.False(t, b.Check(l, 5))
|
||||
u = b.Update(l, 5)
|
||||
assert.False(t, u)
|
||||
assert.False(t, b.Update(l, 5))
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
|
||||
|
||||
// Walk through a few windows in order
|
||||
b = NewBits(10)
|
||||
for i := uint64(0); i <= 100; i++ {
|
||||
for i := uint64(1); i <= 100; i++ {
|
||||
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
||||
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
||||
}
|
||||
|
||||
assert.False(t, b.Check(l, 1), "Out of window check")
|
||||
}
|
||||
|
||||
func TestBitsLargeJumps(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
|
||||
b = NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
|
||||
assert.Equal(t, int64(45), b.lostCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
|
||||
assert.Equal(t, int64(89), b.lostCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
|
||||
assert.Equal(t, int64(188), b.lostCounter.Count())
|
||||
}
|
||||
|
||||
func TestBitsDupeCounter(t *testing.T) {
|
||||
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||
assert.False(t, b.Update(l, 0))
|
||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||
|
||||
//tODO: make sure lostcounter doesn't increase in orderly increment
|
||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||
}
|
||||
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
//assert.True(t, b.Update(0))
|
||||
assert.True(t, b.Update(l, 0))
|
||||
assert.True(t, b.Update(l, 20))
|
||||
assert.True(t, b.Update(l, 21))
|
||||
assert.True(t, b.Update(l, 22))
|
||||
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
assert.True(t, b.Update(l, 27))
|
||||
assert.True(t, b.Update(l, 28))
|
||||
assert.True(t, b.Update(l, 29))
|
||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
|
||||
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
assert.True(t, b.Update(l, 0))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 9))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// 10 will set 0 index, 0 was already set, no lost packets
|
||||
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
}
|
||||
|
||||
func TestBitsLostCounterIssue1(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
assert.True(t, b.Update(l, 4))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 1))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 9))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 2))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 3))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 5))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 6))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 7))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// assert.True(t, b.Update(l, 8))
|
||||
assert.True(t, b.Update(l, 10))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 11))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(l, 14))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
|
||||
assert.True(t, b.Update(l, 19))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 12))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 13))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 15))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 16))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 17))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 18))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 20))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 21))
|
||||
|
||||
// We missed packet 8 above
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
}
|
||||
|
||||
func BenchmarkBits(b *testing.B) {
|
||||
z := NewBits(10)
|
||||
for n := 0; n < b.N; n++ {
|
||||
|
||||
@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||
}
|
||||
|
||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||
nc := &cert.TBSCertificate{
|
||||
Version: v,
|
||||
Curve: c.Curve(),
|
||||
Name: c.Name(),
|
||||
Networks: c.Networks(),
|
||||
UnsafeNetworks: c.UnsafeNetworks(),
|
||||
Groups: c.Groups(),
|
||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||
PublicKey: c.PublicKey(),
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pem
|
||||
}
|
||||
|
||||
func X25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
|
||||
@@ -173,23 +173,26 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
|
||||
var passphrase []byte
|
||||
if !isP11 && *cf.encryption {
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if err == ErrNoTerminal {
|
||||
return fmt.Errorf("out-key must be encrypted interactively")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error reading passphrase: %s", err)
|
||||
}
|
||||
|
||||
if len(passphrase) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||
if len(passphrase) == 0 {
|
||||
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if err == ErrNoTerminal {
|
||||
return fmt.Errorf("out-key must be encrypted interactively")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error reading passphrase: %s", err)
|
||||
}
|
||||
|
||||
if len(passphrase) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(passphrase) == 0 {
|
||||
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -171,6 +171,17 @@ func Test_ca(t *testing.T) {
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test encrypted key with passphrase environment variable
|
||||
os.Remove(keyF.Name())
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||
require.NoError(t, ca(args, ob, eb, testpw))
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
// read encrypted key file and verify default params
|
||||
rb, _ = os.ReadFile(keyF.Name())
|
||||
k, _ := pem.Decode(rb)
|
||||
|
||||
@@ -5,10 +5,28 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A version string that can be set with
|
||||
//
|
||||
// -ldflags "-X main.Build=SOMEVERSION"
|
||||
//
|
||||
// at compile-time.
|
||||
var Build string
|
||||
|
||||
func init() {
|
||||
if Build == "" {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
type helpError struct {
|
||||
s string
|
||||
}
|
||||
|
||||
@@ -116,26 +116,28 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
// naively attempt to decode the private key as though it is not encrypted
|
||||
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
||||
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
||||
// ask for a passphrase until we get one
|
||||
var passphrase []byte
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if errors.Is(err, ErrNoTerminal) {
|
||||
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error reading password: %s", err)
|
||||
}
|
||||
|
||||
if len(passphrase) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||
if len(passphrase) == 0 {
|
||||
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
|
||||
}
|
||||
// ask for a passphrase until we get one
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if errors.Is(err, ErrNoTerminal) {
|
||||
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error reading password: %s", err)
|
||||
}
|
||||
|
||||
if len(passphrase) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(passphrase) == 0 {
|
||||
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
|
||||
}
|
||||
}
|
||||
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
|
||||
|
||||
@@ -379,6 +379,15 @@ func Test_signCert(t *testing.T) {
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test with the proper password in the environment
|
||||
os.Remove(crtF.Name())
|
||||
os.Remove(keyF.Name())
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
// test with the wrong password
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
@@ -389,6 +398,17 @@ func Test_signCert(t *testing.T) {
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test with the wrong password in environment
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
// test with the user not entering a password
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
@@ -18,6 +20,17 @@ import (
|
||||
// at compile-time.
|
||||
var Build string
|
||||
|
||||
func init() {
|
||||
if Build == "" {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
serviceFlag := flag.String("service", "", "Control the system service.")
|
||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
@@ -18,6 +20,17 @@ import (
|
||||
// at compile-time.
|
||||
var Build string
|
||||
|
||||
func init() {
|
||||
if Build == "" {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
"go.yaml.in/yaml/v3"
|
||||
)
|
||||
|
||||
type C struct {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
"go.yaml.in/yaml/v3"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
|
||||
@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
if mainHostInfo {
|
||||
decision = tryRehandshake
|
||||
|
||||
} else {
|
||||
if cm.shouldSwapPrimary(hostinfo) {
|
||||
decision = swapPrimary
|
||||
@@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
}
|
||||
|
||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||
if crt == nil {
|
||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||
return true
|
||||
}
|
||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||
// settle down.
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||
@@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
cm.hostMap.Unlock()
|
||||
}
|
||||
|
||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||
// check and return true.
|
||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||
remoteCert := hostinfo.GetCert()
|
||||
if remoteCert == nil {
|
||||
return false
|
||||
return false //don't tear down tunnels for handshakes in progress
|
||||
}
|
||||
|
||||
caPool := cm.intf.pki.GetCAPool()
|
||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||
return false //cert is still valid! yay!
|
||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||
// Block listed certificates should always be disconnected
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||
return true
|
||||
} else if cm.intf.disconnectInvalid.Load() {
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
return true
|
||||
} else {
|
||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||
return false
|
||||
}
|
||||
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
@@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := cm.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
curCrtVersion := curCrt.Version()
|
||||
myCrt := cs.getCertificate(curCrtVersion)
|
||||
if myCrt == nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("reason", "local certificate removed").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
peerCrt := hostinfo.ConnectionState.peerCert
|
||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
if curCrtVersion < cs.initiatingVersion {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,6 +174,10 @@ func (c *Control) GetHostmap() *HostMap {
|
||||
return c.f.hostMap
|
||||
}
|
||||
|
||||
func (c *Control) GetF() *Interface {
|
||||
return c.f
|
||||
}
|
||||
|
||||
func (c *Control) GetCertState() *CertState {
|
||||
return c.f.pki.getCertState()
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
"go.yaml.in/yaml/v3"
|
||||
)
|
||||
|
||||
func BenchmarkHotPath(b *testing.B) {
|
||||
@@ -97,6 +97,41 @@ func TestGoodHandshake(t *testing.T) {
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestGoodHandshakeNoOverlap(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
|
||||
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
empty := []byte{}
|
||||
t.Log("do something to cause a handshake")
|
||||
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
|
||||
|
||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||
|
||||
t.Log("Get their stage 1 packet")
|
||||
stage1Packet := theirControl.GetFromUDP(true)
|
||||
|
||||
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
|
||||
myControl.InjectUDPPacket(stage1Packet)
|
||||
|
||||
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
|
||||
myControl.WaitForType(header.Test, 0, theirControl)
|
||||
|
||||
t.Log("Make sure our host infos are correct")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestWrongResponderHandshake(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
@@ -464,6 +499,35 @@ func TestRelays(t *testing.T) {
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
}
|
||||
|
||||
func TestRelaysDontCareAboutIps(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
relayControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
}
|
||||
|
||||
func TestReestablishRelays(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
@@ -1227,3 +1291,109 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
||||
|
||||
o := m{
|
||||
"static_host_map": m{
|
||||
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
|
||||
},
|
||||
"lighthouse": m{
|
||||
"hosts": []string{lhVpnIpNet[0].Addr().String()},
|
||||
"local_allow_list": m{
|
||||
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
|
||||
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
|
||||
"10.0.0.0/24": true,
|
||||
"::/0": false,
|
||||
},
|
||||
},
|
||||
}
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
|
||||
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, lhControl, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
lhControl.Start()
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Stand up an ipv6 tunnel between me and them")
|
||||
assert.True(t, myVpnIpNet[1].Addr().Is6())
|
||||
assert.True(t, theirVpnIpNet[1].Addr().Is6())
|
||||
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
|
||||
|
||||
lhControl.Stop()
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
||||
unsafePrefix := "192.168.6.0/24"
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
|
||||
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
|
||||
myCfg := m{
|
||||
"tun": m{
|
||||
"unsafe_routes": []m{route},
|
||||
},
|
||||
}
|
||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
|
||||
t.Logf("my config %v", myConfig)
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
|
||||
spookyDest := netip.MustParseAddr("192.168.6.4")
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
||||
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||
|
||||
t.Log("Get their stage 1 packet so that we can play with it")
|
||||
stage1Packet := theirControl.GetFromUDP(true)
|
||||
|
||||
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
|
||||
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
|
||||
badPacket := stage1Packet.Copy()
|
||||
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
|
||||
myControl.InjectUDPPacket(badPacket)
|
||||
|
||||
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
|
||||
myControl.InjectUDPPacket(stage1Packet)
|
||||
|
||||
t.Log("Wait until we see my cached packet come through")
|
||||
myControl.WaitForType(1, 0, theirControl)
|
||||
|
||||
t.Log("Make sure our host infos are correct")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||
|
||||
t.Log("Get that cached packet and make sure it looks right")
|
||||
myCachedPacket := theirControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
|
||||
|
||||
//reply
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
|
||||
//wait for reply
|
||||
theirControl.WaitForType(1, 0, myControl)
|
||||
theirCachedPacket := myControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
|
||||
|
||||
t.Log("Do a bidirectional tunnel test")
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
@@ -22,15 +22,14 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.yaml.in/yaml/v3"
|
||||
)
|
||||
|
||||
type m = map[string]any
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
@@ -56,7 +55,54 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||
}
|
||||
|
||||
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
|
||||
}
|
||||
|
||||
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||
}
|
||||
|
||||
if len(vpnNetworks) == 0 {
|
||||
panic("no vpn networks")
|
||||
}
|
||||
|
||||
firewallInbound := []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}}
|
||||
|
||||
var unsafeNetworks []netip.Prefix
|
||||
if sUnsafeNetworks != "" {
|
||||
firewallInbound = []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
"local_cidr": "0.0.0.0/0",
|
||||
}}
|
||||
|
||||
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
|
||||
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
unsafeNetworks = append(unsafeNetworks, x)
|
||||
}
|
||||
}
|
||||
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
|
||||
|
||||
caB, err := caCrt.MarshalPEM()
|
||||
if err != nil {
|
||||
@@ -76,11 +122,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": firewallInbound,
|
||||
},
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
@@ -129,6 +171,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
// newServer creates a nebula instance with fewer assumptions
|
||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnNetworks := certs[len(certs)-1].Networks()
|
||||
|
||||
var udpAddr netip.AddrPort
|
||||
if vpnNetworks[0].Addr().Is4() {
|
||||
budpIp := vpnNetworks[0].Addr().As4()
|
||||
budpIp[1] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||
} else {
|
||||
budpIp := vpnNetworks[0].Addr().As16()
|
||||
// beef for funsies
|
||||
budpIp[2] = 190
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
|
||||
caStr := ""
|
||||
for _, ca := range caCrt {
|
||||
x, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr += string(x)
|
||||
}
|
||||
certStr := ""
|
||||
for _, c := range certs {
|
||||
x, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
certStr += string(x)
|
||||
}
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": certStr,
|
||||
"key": string(key),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": m{
|
||||
"outbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
},
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": m{
|
||||
"host": udpAddr.Addr().String(),
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||
"level": l.Level.String(),
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
"connection_alive_interval": 2,
|
||||
},
|
||||
}
|
||||
|
||||
if overrides != nil {
|
||||
final := m{}
|
||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = final
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
c := config.NewC(l)
|
||||
cStr := string(cb)
|
||||
c.LoadString(cStr)
|
||||
|
||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
type doneCb func()
|
||||
|
||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||
@@ -163,10 +308,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
|
||||
// Get both host infos
|
||||
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
||||
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
||||
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||
|
||||
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
||||
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||
|
||||
// Check that both vpn and real addr are correct
|
||||
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
||||
|
||||
@@ -4,12 +4,16 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestDropInactiveTunnels(t *testing.T) {
|
||||
@@ -55,3 +59,309 @@ func TestDropInactiveTunnels(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertUpgrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCert2Pem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v2-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
if c == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
r.Logf("version %d", version)
|
||||
if version == cert.Version2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*10 {
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertDowngrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCertPem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v1-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == cert.Version1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertMismatchCorrection(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == theirVersion {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("wtf")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCrossStackRelaysWork(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||
|
||||
//myVpnV4 := myVpnIpNet[0]
|
||||
myVpnV6 := myVpnIpNet[1]
|
||||
relayVpnV4 := relayVpnIpNet[0]
|
||||
relayVpnV6 := relayVpnIpNet[1]
|
||||
theirVpnV6 := theirVpnIpNet[0]
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
relayControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||
|
||||
t.Log("reply?")
|
||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||
p = r.RouteForAllUntilTxTun(myControl)
|
||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
//t.Log("finish up")
|
||||
//myControl.Stop()
|
||||
//theirControl.Stop()
|
||||
//relayControl.Stop()
|
||||
}
|
||||
|
||||
33
firewall.go
33
firewall.go
@@ -417,8 +417,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||
var ErrUnknownNetworkType = errors.New("unknown network type")
|
||||
var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
|
||||
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
|
||||
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
@@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if h.networks != nil {
|
||||
if !h.networks.Contains(fp.RemoteAddr) {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
switch nwType {
|
||||
case NetworkTypeVPN:
|
||||
break // nothing special
|
||||
case NetworkTypeVPNPeer:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
break // nothing special, one day this may have different FW rules
|
||||
default:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrUnknownNetworkType //should never happen
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
|
||||
207
firewall_test.go
207
firewall_test.go
@@ -8,6 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
@@ -104,13 +106,13 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
@@ -149,7 +151,8 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -174,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||
}
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
h.buildNetworks(myVpnNetworksTable, &c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -226,6 +229,9 @@ func TestFirewall_DropV6(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
@@ -250,7 +256,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||
}
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
h.buildNetworks(myVpnNetworksTable, &c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -453,6 +459,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -478,7 +486,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -493,7 +501,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
peerCert: &c1,
|
||||
},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -510,6 +518,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -541,7 +551,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
|
||||
c2 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -556,7 +566,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||
h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
|
||||
|
||||
c3 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -571,7 +581,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -597,6 +607,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
@@ -620,7 +632,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
|
||||
// Test a remote address match
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
@@ -633,6 +645,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -659,7 +673,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -696,6 +710,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -717,7 +733,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
||||
@@ -1047,6 +1063,171 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
assert.Equal(t, "group1", r.Group)
|
||||
}
|
||||
|
||||
type testcase struct {
|
||||
h *HostInfo
|
||||
p firewall.Packet
|
||||
c cert.Certificate
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
||||
t.Helper()
|
||||
cp := cert.NewCAPool()
|
||||
resetConntrack(fw)
|
||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
||||
if c.err == nil {
|
||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
||||
} else {
|
||||
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
||||
c1 := dummyCert{
|
||||
name: "host1",
|
||||
networks: theirPrefixes,
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &cert.CachedCertificate{
|
||||
Certificate: &c1,
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
},
|
||||
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
|
||||
}
|
||||
for i := range theirPrefixes {
|
||||
h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
||||
}
|
||||
h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
||||
p := firewall.Packet{
|
||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
||||
RemoteAddr: theirPrefixes[0].Addr(),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
return testcase{
|
||||
h: &h,
|
||||
p: p,
|
||||
c: &c1,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
type testsetup struct {
|
||||
c dummyCert
|
||||
myVpnNetworksTable *bart.Lite
|
||||
fw *Firewall
|
||||
}
|
||||
|
||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
||||
c := dummyCert{
|
||||
name: "me",
|
||||
networks: myPrefixes,
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
|
||||
return newSetupFromCert(t, l, c)
|
||||
}
|
||||
|
||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
for _, prefix := range c.Networks() {
|
||||
myVpnNetworksTable.Insert(prefix)
|
||||
}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
|
||||
return testsetup{
|
||||
c: c,
|
||||
fw: fw,
|
||||
myVpnNetworksTable: myVpnNetworksTable,
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
||||
t.Run("allow inbound all matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("allow inbound local matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
|
||||
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("block inbound remote mismatched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
|
||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("Block a vpn peer packet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
twoPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
|
||||
}
|
||||
t.Run("allow inbound one matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, nil, twoPrefixes...)
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("block inbound multimismatch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
|
||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
|
||||
tc := buildTestCase(setup2, nil, twoPrefixes...)
|
||||
tc.p.RemoteAddr = twoPrefixes[1].Addr()
|
||||
tc.Test(t, setup2.fw)
|
||||
})
|
||||
t.Run("allow inbound unsafe route", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
||||
c := dummyCert{
|
||||
name: "me",
|
||||
networks: []netip.Prefix{myPrefix},
|
||||
unsafeNetworks: []netip.Prefix{unsafePrefix},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
unsafeSetup := newSetupFromCert(t, l, c)
|
||||
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
||||
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
||||
tc.err = ErrNoMatchingRule
|
||||
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
||||
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
|
||||
tc.err = nil
|
||||
tc.Test(t, unsafeSetup.fw) //should pass
|
||||
})
|
||||
}
|
||||
|
||||
type addRuleCall struct {
|
||||
incoming bool
|
||||
proto uint8
|
||||
|
||||
21
go.mod
21
go.mod
@@ -8,7 +8,7 @@ require (
|
||||
github.com/armon/go-radix v1.0.0
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/gaissmai/bart v0.25.0
|
||||
github.com/gaissmai/bart v0.26.0
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/kardianos/service v1.2.4
|
||||
@@ -22,18 +22,19 @@ require (
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.43.0
|
||||
go.yaml.in/yaml/v3 v3.0.4
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/net v0.45.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/sync v0.18.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.org/x/term v0.37.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/protobuf v1.36.8
|
||||
google.golang.org/protobuf v1.36.10
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -49,6 +50,6 @@ require (
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
golang.org/x/mod v0.24.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
)
|
||||
|
||||
42
go.sum
42
go.sum
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
|
||||
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
|
||||
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@@ -155,13 +155,15 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
@@ -180,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -189,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -207,16 +209,16 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
@@ -230,8 +232,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
@@ -242,8 +244,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@@ -257,5 +259,5 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g=
|
||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU=
|
||||
|
||||
159
handshake_ix.go
159
handshake_ix.go
@@ -2,7 +2,6 @@ package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
@@ -23,13 +22,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.initiatingVersion
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
} else if v < cert.Version2 {
|
||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +51,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return false
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
@@ -103,6 +107,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.initiatingVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
return
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
@@ -143,8 +148,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp, fperr := rc.Fingerprint()
|
||||
if fperr != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
@@ -163,16 +168,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if rc == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Unable to handshake with host due to missing certificate version")
|
||||
return
|
||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if myCertOtherVersion == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithError(err).WithFields(m{
|
||||
"udpAddr": addr,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert": remoteCert,
|
||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||
}
|
||||
} else {
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = myCertOtherVersion
|
||||
}
|
||||
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = rc
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
@@ -183,17 +191,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
return
|
||||
}
|
||||
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
certName := remoteCert.Certificate.Name()
|
||||
certVersion := remoteCert.Certificate.Version()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
|
||||
for _, network := range remoteCert.Certificate.Networks() {
|
||||
vpnAddr := network.Addr()
|
||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
@@ -201,24 +209,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
return
|
||||
}
|
||||
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
@@ -255,26 +249,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
},
|
||||
}
|
||||
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message received")
|
||||
msgRxL := f.l.WithFields(m{
|
||||
"vpnAddrs": vpnAddrs,
|
||||
"udpAddr": addr,
|
||||
"certName": certName,
|
||||
"certVersion": certVersion,
|
||||
"fingerprint": fingerprint,
|
||||
"issuer": issuer,
|
||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||
"responderIndex": hs.Details.ResponderIndex,
|
||||
"remoteIndex": h.RemoteIndex,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
})
|
||||
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||
if hs.Details.Cert == nil {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("certVersion", ci.myCert.Version()).
|
||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return
|
||||
}
|
||||
@@ -332,7 +330,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||
hostinfo.SetRemote(addr)
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
|
||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||
if err != nil {
|
||||
@@ -573,31 +571,22 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
}
|
||||
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
for _, network := range vpnNetworks {
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
vpnAddr := network.Addr()
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
correctHostResponded := false
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||
correctHostResponded = true
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure the right host responded
|
||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||
if !correctHostResponded {
|
||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||
WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
@@ -609,6 +598,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
|
||||
// Create a new hostinfo/handshake for the intended vpn ip
|
||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||
// Block the current used address
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
@@ -635,7 +625,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
@@ -643,12 +633,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||
Info("Handshake message received")
|
||||
WithField("sentCachedPackets", len(hh.packetStore))
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
|
||||
@@ -68,11 +68,12 @@ type HandshakeManager struct {
|
||||
type HandshakeHostInfo struct {
|
||||
sync.Mutex
|
||||
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
}
|
||||
@@ -268,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||
// Send a RelayRequest to all known Relay IP's
|
||||
for _, relay := range hostinfo.remotes.relays {
|
||||
// Don't relay to myself
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
if relay == vpnIp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
// Don't relay to myself
|
||||
if hm.f.myVpnAddrsTable.Contains(relay) {
|
||||
continue
|
||||
}
|
||||
|
||||
40
hostmap.go
40
hostmap.go
@@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
||||
rs.relayForByIdx[idx] = r
|
||||
}
|
||||
|
||||
type NetworkType uint8
|
||||
|
||||
const (
|
||||
NetworkTypeUnknown NetworkType = iota
|
||||
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
||||
NetworkTypeVPN
|
||||
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
||||
NetworkTypeVPNPeer
|
||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||
NetworkTypeUnsafe
|
||||
)
|
||||
|
||||
type HostInfo struct {
|
||||
remote netip.AddrPort
|
||||
remotes *RemoteList
|
||||
@@ -225,8 +237,8 @@ type HostInfo struct {
|
||||
// vpn networks but were removed because they are not usable
|
||||
vpnAddrs []netip.Addr
|
||||
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Lite
|
||||
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
||||
networks *bart.Table[NetworkType]
|
||||
relayState RelayState
|
||||
|
||||
// HandshakePacket records the packets used to create this hostinfo
|
||||
@@ -730,20 +742,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
return
|
||||
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
|
||||
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
|
||||
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
||||
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
|
||||
return // Simple case, no BART needed
|
||||
}
|
||||
}
|
||||
|
||||
i.networks = new(bart.Lite)
|
||||
for _, network := range networks {
|
||||
i.networks = new(bart.Table[NetworkType])
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
i.networks.Insert(nprefix)
|
||||
if myVpnNetworksTable.Contains(network.Addr()) {
|
||||
i.networks.Insert(nprefix, NetworkTypeVPN)
|
||||
} else {
|
||||
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
|
||||
}
|
||||
}
|
||||
|
||||
for _, network := range unsafeNetworks {
|
||||
i.networks.Insert(network)
|
||||
for _, network := range c.UnsafeNetworks() {
|
||||
i.networks.Insert(network, NetworkTypeUnsafe)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
17
inside.go
17
inside.go
@@ -33,7 +33,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||
// TUN device.
|
||||
if immediatelyForwardToSelf {
|
||||
if err := f.writeTun(q, packet); err != nil {
|
||||
_, err := f.readers[q].Write(packet)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to forward to tun")
|
||||
}
|
||||
}
|
||||
@@ -90,7 +91,8 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.writeTun(q, out); err != nil {
|
||||
_, err := f.readers[q].Write(out)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to write to tun")
|
||||
}
|
||||
}
|
||||
@@ -118,9 +120,10 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||
}
|
||||
|
||||
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
|
||||
// it does not check if it is within our vpn networks!
|
||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||
f.getOrHandshakeNoRouting(vpnAddr, nil)
|
||||
f.handshakeManager.GetOrHandshake(vpnAddr, nil)
|
||||
}
|
||||
|
||||
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
||||
@@ -136,7 +139,6 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
|
||||
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
||||
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
|
||||
destinationAddr := fwPacket.RemoteAddr
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
||||
@@ -229,9 +231,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
}
|
||||
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
|
||||
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
|
||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
|
||||
101
interface.go
101
interface.go
@@ -47,7 +47,6 @@ type InterfaceConfig struct {
|
||||
reQueryWait time.Duration
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
batchSize int
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
@@ -85,7 +84,6 @@ type Interface struct {
|
||||
version string
|
||||
|
||||
conntrackCacheTimeout time.Duration
|
||||
batchSize int
|
||||
|
||||
writers []udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
@@ -112,16 +110,6 @@ type EncWriter interface {
|
||||
GetCertState() *CertState
|
||||
}
|
||||
|
||||
// BatchReader is an interface for readers that support vectorized packet reading
|
||||
type BatchReader interface {
|
||||
BatchRead(buffers [][]byte, sizes []int) (int, error)
|
||||
}
|
||||
|
||||
// BatchWriter is an interface for writers that support vectorized packet writing
|
||||
type BatchWriter interface {
|
||||
BatchWrite([][]byte) (int, error)
|
||||
}
|
||||
|
||||
type sendRecvErrorConfig uint8
|
||||
|
||||
const (
|
||||
@@ -198,7 +186,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
relayManager: c.relayManager,
|
||||
connectionManager: c.connectionManager,
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
batchSize: c.batchSize,
|
||||
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
messageMetrics: c.MessageMetrics,
|
||||
@@ -282,7 +269,7 @@ func (f *Interface) listenOut(i int) {
|
||||
plaintext := make([]byte, udp.MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12)
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||
@@ -292,16 +279,6 @@ func (f *Interface) listenOut(i int) {
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
// Check if reader supports batch operations
|
||||
if batchReader, ok := reader.(BatchReader); ok {
|
||||
err := f.listenInBatch(batchReader, i)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Fatal error in batch packet reader, exiting goroutine")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Fall back to single-packet mode
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
@@ -316,85 +293,15 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
return
|
||||
}
|
||||
|
||||
f.l.WithError(err).Error("Fatal error while reading outbound packet, exiting goroutine")
|
||||
return
|
||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||
}
|
||||
}
|
||||
|
||||
// listenInBatch handles vectorized packet reading for improved performance
|
||||
func (f *Interface) listenInBatch(reader BatchReader, i int) error {
|
||||
// Allocate per-packet state and buffers for batch reading
|
||||
batchSize := f.batchSize
|
||||
if batchSize <= 0 {
|
||||
batchSize = 64 // Fallback to default if not configured
|
||||
}
|
||||
fwPackets := make([]*firewall.Packet, batchSize)
|
||||
outBuffers := make([][]byte, batchSize)
|
||||
nbBuffers := make([][]byte, batchSize)
|
||||
packets := make([][]byte, batchSize)
|
||||
sizes := make([]int, batchSize)
|
||||
|
||||
for j := 0; j < batchSize; j++ {
|
||||
fwPackets[j] = &firewall.Packet{}
|
||||
outBuffers[j] = make([]byte, mtu)
|
||||
nbBuffers[j] = make([]byte, 12)
|
||||
packets[j] = make([]byte, mtu)
|
||||
}
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := reader.BatchRead(packets, sizes)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("error while batch reading outbound packets: %w", err)
|
||||
}
|
||||
|
||||
// Process each packet in the batch
|
||||
cache := conntrackCache.Get(f.l)
|
||||
for idx := 0; idx < n; idx++ {
|
||||
if sizes[idx] > 0 {
|
||||
// Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
|
||||
stateIdx := idx % len(fwPackets)
|
||||
f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeTunBatch attempts to write multiple packets to the TUN device using batch operations if supported
|
||||
func (f *Interface) writeTunBatch(q int, packets [][]byte) error {
|
||||
if len(packets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the reader/writer supports batch operations
|
||||
if batchWriter, ok := f.readers[q].(BatchWriter); ok {
|
||||
_, err := batchWriter.BatchWrite(packets)
|
||||
return err
|
||||
}
|
||||
|
||||
// Fall back to writing packets individually
|
||||
for _, packet := range packets {
|
||||
if _, err := f.readers[q].Write(packet); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeTun writes a single packet to the TUN device
|
||||
func (f *Interface) writeTun(q int, packet []byte) error {
|
||||
_, err := f.readers[q].Write(packet)
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
c.RegisterReloadCallback(f.reloadFirewall)
|
||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||
|
||||
@@ -360,7 +360,8 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
||||
}
|
||||
out[i] = addr
|
||||
}
|
||||
@@ -431,7 +432,8 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
||||
}
|
||||
|
||||
vals, ok := v.([]any)
|
||||
@@ -1337,12 +1339,19 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
}
|
||||
}
|
||||
|
||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
||||
for _, a := range n.Details.V4AddrPorts {
|
||||
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
b := protoV4AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V6AddrPorts {
|
||||
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
b := protoV6AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
"go.yaml.in/yaml/v3"
|
||||
)
|
||||
|
||||
func TestOldIPv4Only(t *testing.T) {
|
||||
|
||||
27
main.go
27
main.go
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -13,7 +15,7 @@ import (
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"gopkg.in/yaml.v3"
|
||||
"go.yaml.in/yaml/v3"
|
||||
)
|
||||
|
||||
type m = map[string]any
|
||||
@@ -27,6 +29,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}()
|
||||
|
||||
if buildVersion == "" {
|
||||
buildVersion = moduleVersion()
|
||||
}
|
||||
|
||||
l := logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
@@ -75,7 +81,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
if c.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
||||
sshStart = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,7 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||
punchy: punchy,
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
batchSize: c.GetInt("tun.batch_size", 64),
|
||||
l: l,
|
||||
}
|
||||
|
||||
@@ -296,3 +302,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
connManager.Start,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func moduleVersion() string {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, dep := range info.Deps {
|
||||
if dep.Path == "github.com/slackhq/nebula" {
|
||||
return strings.TrimPrefix(dep.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -333,13 +333,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
}
|
||||
|
||||
fp.Protocol = uint8(proto)
|
||||
ports := data[offset : offset+4]
|
||||
if incoming {
|
||||
fp.RemotePort = binary.BigEndian.Uint16(ports[0:2])
|
||||
fp.LocalPort = binary.BigEndian.Uint16(ports[2:4])
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||
} else {
|
||||
fp.LocalPort = binary.BigEndian.Uint16(ports[0:2])
|
||||
fp.RemotePort = binary.BigEndian.Uint16(ports[2:4])
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||
}
|
||||
|
||||
fp.Fragment = false
|
||||
|
||||
@@ -3,6 +3,7 @@ package overlay
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -304,3 +305,29 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
|
||||
// Make sure o contains the lowest form of i
|
||||
if !o.Contains(i.IP.Mask(i.Mask)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Find the max ip in i
|
||||
ip4 := i.IP.To4()
|
||||
if ip4 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
last := make(net.IP, len(ip4))
|
||||
copy(last, ip4)
|
||||
for x := range ip4 {
|
||||
last[x] |= ^i.Mask[x]
|
||||
}
|
||||
|
||||
// Make sure o contains the max
|
||||
if !o.Contains(last) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -225,7 +225,6 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
// no mtu
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Equal(t, 0, routes[0].MTU)
|
||||
|
||||
@@ -319,7 +318,7 @@ func Test_makeRouteTree(t *testing.T) {
|
||||
|
||||
ip, err = netip.ParseAddr("1.1.0.1")
|
||||
require.NoError(t, err)
|
||||
_, ok = routeTree.Lookup(ip)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -70,3 +72,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
func flipBytes(b []byte) []byte {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] ^= 0xFF
|
||||
}
|
||||
return b
|
||||
}
|
||||
func orBytes(a []byte, b []byte) []byte {
|
||||
ret := make([]byte, len(a))
|
||||
for i := 0; i < len(a); i++ {
|
||||
ret[i] = a[i] | b[i]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
||||
broadcast, _ := netip.AddrFromSlice(
|
||||
orBytes(
|
||||
cidr.Addr().AsSlice(),
|
||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
||||
),
|
||||
)
|
||||
return broadcast
|
||||
}
|
||||
|
||||
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||
for _, gateway := range gateways {
|
||||
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
||||
return gateway, nil
|
||||
}
|
||||
|
||||
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
||||
return gateway, nil
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//go:build darwin && !ios && !e2e_testing
|
||||
// +build darwin,!ios,!e2e_testing
|
||||
//go:build !ios && !e2e_testing
|
||||
// +build !ios,!e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
@@ -8,27 +8,48 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
linkAddr *netroute.LinkAddr
|
||||
io.ReadWriteCloser
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
DefaultMTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
}
|
||||
|
||||
// ioctl structures for Darwin network configuration
|
||||
type ifReq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Flags uint16
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
const (
|
||||
_SIOCAIFADDR_IN6 = 2155899162
|
||||
_UTUN_OPT_IFNAME = 2
|
||||
_IN6_IFF_NODAD = 0x0020
|
||||
_IN6_IFF_SECURED = 0x0400
|
||||
utunControlName = "com.apple.net.utun_control"
|
||||
)
|
||||
|
||||
type ifreqMTU struct {
|
||||
Name [16]byte
|
||||
MTU int32
|
||||
@@ -58,61 +79,60 @@ type ifreqAlias6 struct {
|
||||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
const (
|
||||
_SIOCAIFADDR_IN6 = 2155899162
|
||||
_IN6_IFF_NODAD = 0x0020
|
||||
)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported on Darwin")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
deviceName := "utun"
|
||||
|
||||
// Parse device name to handle utun[0-9]+ format
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
ifIndex := -1
|
||||
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
||||
if err != nil || ifIndex < 0 {
|
||||
// NOTE: we don't make this error so we don't break existing
|
||||
// configs that set a name before it was used.
|
||||
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring")
|
||||
} else {
|
||||
deviceName = name
|
||||
ifIndex = -1
|
||||
}
|
||||
}
|
||||
|
||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||
|
||||
// Create WireGuard TUN device
|
||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||
return nil, fmt.Errorf("system socket: %v", err)
|
||||
}
|
||||
|
||||
// Get the actual device name
|
||||
actualName, err := tunDevice.Name()
|
||||
var ctlInfo = &unix.CtlInfo{}
|
||||
copy(ctlInfo.Name[:], utunControlName)
|
||||
|
||||
err = unix.IoctlCtlInfo(fd, ctlInfo)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
|
||||
}
|
||||
|
||||
t := &wgTun{
|
||||
tunDevice: tunDevice,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MaxMTU: mtu,
|
||||
DefaultMTU: mtu,
|
||||
l: l,
|
||||
err = unix.Connect(fd, &unix.SockaddrCtl{
|
||||
ID: ctlInfo.Id,
|
||||
Unit: uint32(ifIndex) + 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
|
||||
}
|
||||
|
||||
// Create Darwin-specific route manager
|
||||
t.routeManager = &tun{}
|
||||
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SetNonblock: %v", err)
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
||||
Device: name,
|
||||
vpnNetworks: vpnNetworks,
|
||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -123,251 +143,215 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
})
|
||||
|
||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (rm *tun) Activate(t *wgTun) error {
|
||||
name, err := t.tunDevice.Name()
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
return t.ReadWriteCloser.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
devName := t.deviceBytes()
|
||||
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
unix.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device name: %w", err)
|
||||
return err
|
||||
}
|
||||
defer unix.Close(s)
|
||||
|
||||
fd := uintptr(s)
|
||||
|
||||
// Set the MTU on the device
|
||||
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
|
||||
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %v", err)
|
||||
}
|
||||
|
||||
// Set the MTU
|
||||
rm.SetMTU(t, t.MaxMTU)
|
||||
|
||||
// Add IP addresses
|
||||
for _, network := range t.vpnNetworks {
|
||||
if err := rm.addIP(t, name, network); err != nil {
|
||||
return err
|
||||
}
|
||||
// Get the device flags
|
||||
ifrf := ifReq{Name: devName}
|
||||
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to get tun flags: %s", err)
|
||||
}
|
||||
|
||||
// Bring up the interface using ioctl
|
||||
if err := rm.bringUpInterface(name); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %w", err)
|
||||
}
|
||||
|
||||
// Get the link address for routing
|
||||
linkAddr, err := getLinkAddr(name)
|
||||
linkAddr, err := getLinkAddr(t.Device)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get link address: %w", err)
|
||||
return err
|
||||
}
|
||||
if linkAddr == nil {
|
||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||
}
|
||||
rm.linkAddr = linkAddr
|
||||
t.linkAddr = linkAddr
|
||||
|
||||
// Set the routes
|
||||
if err := rm.AddRoutes(t, false); err != nil {
|
||||
for _, network := range t.vpnNetworks {
|
||||
if network.Addr().Is4() {
|
||||
err = t.activate4(network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = t.activate6(network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to run tun device: %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) activate4(network netip.Prefix) error {
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
unix.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(s)
|
||||
|
||||
ifr := ifreqAlias4{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: network.Addr().As4(),
|
||||
},
|
||||
DstAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: network.Addr().As4(),
|
||||
},
|
||||
MaskAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(network).As4(),
|
||||
},
|
||||
}
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun v4 address: %s", err)
|
||||
}
|
||||
|
||||
err = addRoute(network, t.linkAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) bringUpInterface(name string) error {
|
||||
// Open a socket for ioctl
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
|
||||
func (t *tun) activate6(network netip.Prefix) error {
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET6,
|
||||
unix.SOCK_DGRAM,
|
||||
unix.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create socket: %w", err)
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
// Get current flags
|
||||
var ifrf ifReq
|
||||
copy(ifrf.Name[:], name)
|
||||
|
||||
if err := ioctl(uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to get interface flags: %w", err)
|
||||
}
|
||||
|
||||
// Set IFF_UP and IFF_RUNNING flags
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||
|
||||
if err := ioctl(uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to set interface flags: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||
return
|
||||
}
|
||||
|
||||
// Open a socket for ioctl
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to create socket for MTU set")
|
||||
return
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
// Prepare the ioctl request
|
||||
var ifr ifreqMTU
|
||||
copy(ifr.Name[:], name)
|
||||
ifr.MTU = int32(mtu)
|
||||
|
||||
// Set the MTU using ioctl
|
||||
if err := ioctl(uintptr(fd), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
t.l.WithError(err).Error("Failed to set tun mtu via ioctl")
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||
// On Darwin, routes are set via ifconfig and route commands
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
err := rm.addRoute(r.Cidr)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
t.l.WithField("route", r.Cidr).
|
||||
Warnf("unable to add unsafe_route, identical route already exists")
|
||||
} else {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
err := rm.delRoute(r.Cidr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||
// Darwin doesn't support multi-queue TUN devices in the same way as Linux
|
||||
// Return a reader that wraps the same device
|
||||
return &wgTunReader{
|
||||
parent: t,
|
||||
tunDevice: t.tunDevice,
|
||||
offset: 0,
|
||||
l: t.l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
||||
addr := network.Addr()
|
||||
|
||||
if addr.Is4() {
|
||||
return rm.addIPv4(name, network)
|
||||
} else {
|
||||
return rm.addIPv6(name, network)
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *tun) addIPv4(name string, network netip.Prefix) error {
|
||||
// Open an IPv4 socket for ioctl
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IPv4 socket: %w", err)
|
||||
return err
|
||||
}
|
||||
defer unix.Close(s)
|
||||
|
||||
var ifr ifreqAlias4
|
||||
copy(ifr.Name[:], name)
|
||||
|
||||
// Set the address
|
||||
ifr.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: network.Addr().As4(),
|
||||
ifr := ifreqAlias6{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: network.Addr().As16(),
|
||||
},
|
||||
PrefixMask: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(network).As16(),
|
||||
},
|
||||
Lifetime: addrLifetime{
|
||||
// never expires
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
Flags: _IN6_IFF_NODAD,
|
||||
}
|
||||
|
||||
// Set the destination address (same as address for point-to-point)
|
||||
ifr.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: network.Addr().As4(),
|
||||
}
|
||||
|
||||
// Set the netmask
|
||||
ifr.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(network).As4(),
|
||||
}
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set IPv4 address via ioctl: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) addIPv6(name string, network netip.Prefix) error {
|
||||
// Open an IPv6 socket for ioctl
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IPv6 socket: %w", err)
|
||||
}
|
||||
defer unix.Close(s)
|
||||
|
||||
var ifr ifreqAlias6
|
||||
copy(ifr.Name[:], name)
|
||||
|
||||
// Set the address
|
||||
ifr.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: network.Addr().As16(),
|
||||
}
|
||||
|
||||
// Set the prefix mask
|
||||
ifr.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(network).As16(),
|
||||
}
|
||||
|
||||
// Set lifetime (never expires)
|
||||
ifr.Lifetime = addrLifetime{
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
}
|
||||
|
||||
// Set flags (no DAD - Duplicate Address Detection)
|
||||
ifr.Flags = _IN6_IFF_NODAD
|
||||
|
||||
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set IPv6 address via ioctl: %w", err)
|
||||
return fmt.Errorf("failed to set tun address: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, ok := t.routeTree.Load().Lookup(ip)
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
return routing.Gateways{}
|
||||
}
|
||||
|
||||
// Get the LinkAddr for the interface of the given name
|
||||
// Is there an easier way to fetch this when we create the interface?
|
||||
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
|
||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||
if err != nil {
|
||||
@@ -393,7 +377,53 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (rm *tun) addRoute(prefix netip.Prefix) error {
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
err := addRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
t.l.WithField("route", r.Cidr).
|
||||
Warnf("unable to add unsafe_route, identical route already exists")
|
||||
} else {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
@@ -411,13 +441,13 @@ func (rm *tun) addRoute(prefix netip.Prefix) error {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,7 +464,7 @@ func (rm *tun) addRoute(prefix netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) delRoute(prefix netip.Prefix) error {
|
||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
@@ -451,13 +481,13 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -465,7 +495,6 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
@@ -474,34 +503,52 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ioctl(a1, a2, a3 uintptr) error {
|
||||
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
return nil
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.ReadWriteCloser.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
bits := prefix.Bits()
|
||||
if prefix.Addr().Is4() {
|
||||
// Create IPv4 netmask from prefix length
|
||||
mask := ^uint32(0) << (32 - bits)
|
||||
return netip.AddrFrom4([4]byte{
|
||||
byte(mask >> 24),
|
||||
byte(mask >> 16),
|
||||
byte(mask >> 8),
|
||||
byte(mask),
|
||||
})
|
||||
} else {
|
||||
// Create IPv6 netmask from prefix length
|
||||
var mask [16]byte
|
||||
for i := 0; i < bits/8; i++ {
|
||||
mask[i] = 0xff
|
||||
}
|
||||
if bits%8 != 0 {
|
||||
mask[bits/8] = ^byte(0) << (8 - bits%8)
|
||||
}
|
||||
return netip.AddrFrom16(mask)
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
copy(buf[4:], from)
|
||||
|
||||
n, err := t.ReadWriteCloser.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
@@ -1,77 +1,284 @@
|
||||
//go:build freebsd && !e2e_testing
|
||||
// +build freebsd,!e2e_testing
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type tun struct{}
|
||||
const (
|
||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||
FIODGNAME = 0x80106678
|
||||
TUNSIFMODE = 0x8004745e
|
||||
TUNSIFHEAD = 0x80047460
|
||||
OSIOCAIFADDR_IN6 = 0x8088691b
|
||||
IN6_IFF_NODAD = 0x0020
|
||||
)
|
||||
|
||||
type fiodgnameArg struct {
|
||||
length int32
|
||||
pad [4]byte
|
||||
buf unsafe.Pointer
|
||||
}
|
||||
|
||||
// ifreqRename is used for renaming network interfaces on FreeBSD
|
||||
type ifreqRename struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Data uintptr
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported on FreeBSD")
|
||||
type ifreqDestroy struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
pad [16]byte
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||
deviceName := c.GetString("tun.dev", "tun")
|
||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||
type ifReq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Flags uint16
|
||||
}
|
||||
|
||||
// Create WireGuard TUN device
|
||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||
type ifreqMTU struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
MTU int32
|
||||
}
|
||||
|
||||
type addrLifetime struct {
|
||||
Expire uint64
|
||||
Preferred uint64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
}
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
devFd int
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
// Get the actual device name
|
||||
actualName, err := tunDevice.Name()
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
head := make([]byte, 4)
|
||||
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&to[0], uint64(len(to))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
// fix bytes read number to exclude header
|
||||
bytesRead := int(n)
|
||||
if bytesRead < 0 {
|
||||
return bytesRead, err
|
||||
} else if bytesRead < 4 {
|
||||
return 0, err
|
||||
} else {
|
||||
return bytesRead - 4, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
if len(from) <= 1 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
ipVer := from[0] >> 4
|
||||
var head []byte
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
if ipVer == 4 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
||||
} else if ipVer == 6 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&from[0], uint64(len(from))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return int(n) - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.devFd >= 0 {
|
||||
err := syscall.Close(t.devFd)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error closing device")
|
||||
}
|
||||
t.devFd = -1
|
||||
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
||||
defer close(c)
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err == nil {
|
||||
defer syscall.Close(s)
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
}
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error destroying tunnel")
|
||||
}
|
||||
}()
|
||||
|
||||
// wait up to 1 second so we start blocking at the ioctl
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var fd int
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName != "" {
|
||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
||||
}
|
||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||
// If the device doesn't already exist, request a new one and rename it
|
||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
||||
}
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the name of the interface
|
||||
var name [16]byte
|
||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||
|
||||
if ctrlErr == nil {
|
||||
// set broadcast mode and multicast
|
||||
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
||||
}
|
||||
|
||||
if ctrlErr == nil {
|
||||
// turn on link-layer mode, to support ipv6
|
||||
ifhead := uint32(1)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
||||
}
|
||||
|
||||
if ctrlErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ifName := string(bytes.TrimRight(name[:], "\x00"))
|
||||
if deviceName == "" {
|
||||
deviceName = ifName
|
||||
}
|
||||
|
||||
// If the name doesn't match the desired interface name, rename it now
|
||||
if actualName != deviceName && deviceName != "" && deviceName != "tun" {
|
||||
if err := renameInterface(actualName, deviceName); err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, fmt.Errorf("failed to rename interface from %s to %s: %w", actualName, deviceName, err)
|
||||
if ifName != deviceName {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
actualName = deviceName
|
||||
defer syscall.Close(s)
|
||||
|
||||
fd := uintptr(s)
|
||||
|
||||
var fromName [16]byte
|
||||
var toName [16]byte
|
||||
copy(fromName[:], ifName)
|
||||
copy(toName[:], deviceName)
|
||||
|
||||
ifrr := ifreqRename{
|
||||
Name: fromName,
|
||||
Data: uintptr(unsafe.Pointer(&toName)),
|
||||
}
|
||||
|
||||
// Set the device name
|
||||
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
||||
}
|
||||
|
||||
t := &wgTun{
|
||||
tunDevice: tunDevice,
|
||||
t := &tun{
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MaxMTU: mtu,
|
||||
DefaultMTU: mtu,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
devFd: fd,
|
||||
}
|
||||
|
||||
// Create FreeBSD-specific route manager
|
||||
t.routeManager = &tun{}
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -82,86 +289,180 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
})
|
||||
|
||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (rm *tun) Activate(t *wgTun) error {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device name: %w", err)
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
ifr := ifreqAlias4{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
},
|
||||
DstAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: getBroadcast(cidr).As4(),
|
||||
},
|
||||
MaskAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
},
|
||||
VHid: 0,
|
||||
}
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set the MTU
|
||||
rm.SetMTU(t, t.MaxMTU)
|
||||
if cidr.Addr().Is6() {
|
||||
ifr := ifreqAlias6{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
},
|
||||
PrefixMask: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
},
|
||||
Lifetime: addrLifetime{
|
||||
Expire: 0,
|
||||
Preferred: 0,
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
Flags: IN6_IFF_NODAD,
|
||||
}
|
||||
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
// Add IP addresses
|
||||
for _, network := range t.vpnNetworks {
|
||||
if err := rm.addIP(t, name, network); err != nil {
|
||||
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
// Setup our default MTU
|
||||
err := t.setMTU()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
linkAddr, err := getLinkAddr(t.Device)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if linkAddr == nil {
|
||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||
}
|
||||
t.linkAddr = linkAddr
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %w", err)
|
||||
}
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
// Set the routes
|
||||
if err := rm.AddRoutes(t, false); err != nil {
|
||||
func (t *tun) setMTU() error {
|
||||
// Set the MTU on the device
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
||||
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||
return
|
||||
}
|
||||
|
||||
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
|
||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||
}
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||
// On FreeBSD, routes are set via ifconfig and route commands
|
||||
return nil
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device name: %w", err)
|
||||
}
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
// Add route using route command
|
||||
args := []string{"add"}
|
||||
|
||||
if r.Cidr.Addr().Is6() {
|
||||
args = append(args, "-inet6")
|
||||
} else {
|
||||
args = append(args, "-inet")
|
||||
}
|
||||
|
||||
args = append(args, r.Cidr.String(), "-interface", name)
|
||||
|
||||
if r.Metric > 0 {
|
||||
// FreeBSD doesn't support route metrics directly like Linux
|
||||
t.l.WithField("route", r).Warn("Route metrics are not fully supported on FreeBSD")
|
||||
}
|
||||
|
||||
err := runCommandBSD("route", args...)
|
||||
err := addRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
@@ -177,99 +478,142 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get device name for route removal")
|
||||
return
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
args := []string{"delete"}
|
||||
|
||||
if r.Cidr.Addr().Is6() {
|
||||
args = append(args, "-inet6")
|
||||
} else {
|
||||
args = append(args, "-inet")
|
||||
}
|
||||
|
||||
args = append(args, r.Cidr.String(), "-interface", name)
|
||||
|
||||
err := runCommandBSD("route", args...)
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||
// FreeBSD doesn't support multi-queue TUN devices in the same way as Linux
|
||||
// Return a reader that wraps the same device
|
||||
return &wgTunReader{
|
||||
parent: t,
|
||||
tunDevice: t.tunDevice,
|
||||
offset: 0,
|
||||
l: t.l,
|
||||
}, nil
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
||||
addr := network.Addr()
|
||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
if addr.Is4() {
|
||||
// For IPv4: ifconfig tun0 10.0.0.1/24
|
||||
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
|
||||
return fmt.Errorf("failed to add IPv4 address: %w", err)
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
|
||||
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
|
||||
return fmt.Errorf("failed to add IPv6 address: %w", err)
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCommandBSD(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func renameInterface(fromName, toName string) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create socket: %w", err)
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
fd := uintptr(s)
|
||||
|
||||
var fromNameBytes [unix.IFNAMSIZ]byte
|
||||
var toNameBytes [unix.IFNAMSIZ]byte
|
||||
copy(fromNameBytes[:], fromName)
|
||||
copy(toNameBytes[:], toName)
|
||||
|
||||
ifrr := ifreqRename{
|
||||
Name: fromNameBytes,
|
||||
Data: uintptr(unsafe.Pointer(&toNameBytes)),
|
||||
}
|
||||
|
||||
// Set the device name using SIOCSIFNAME ioctl
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
||||
if errno != 0 {
|
||||
return fmt.Errorf("SIOCSIFNAME ioctl failed: %w", errno)
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
fmt.Println("DOING CHANGE")
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLinkAddr Gets the link address for the interface of the given name
|
||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
switch m := m.(type) {
|
||||
case *netroute.InterfaceMessage:
|
||||
if m.Name == name {
|
||||
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
||||
if ok {
|
||||
return sa, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
//go:build !android && !e2e_testing
|
||||
// +build !android,!e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
@@ -9,105 +9,133 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
txQueueLen int
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
routeChan chan struct{}
|
||||
useSystemRoutes bool
|
||||
useSystemRoutesBufferSize int
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*wgTun, error) {
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||
|
||||
// Create WireGuard TUN device
|
||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||
}
|
||||
|
||||
// Get the actual device name
|
||||
actualName, err := tunDevice.Name()
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||
}
|
||||
|
||||
t := &wgTun{
|
||||
tunDevice: tunDevice,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MaxMTU: mtu,
|
||||
DefaultMTU: mtu,
|
||||
l: l,
|
||||
}
|
||||
|
||||
// Create Linux-specific route manager
|
||||
routeManager := &tun{
|
||||
txQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
}
|
||||
t.routeManager = routeManager
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||
|
||||
return t, nil
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*wgTun, error) {
|
||||
// Create TUN device from file descriptor
|
||||
type ifReq struct {
|
||||
Name [16]byte
|
||||
Flags uint16
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
type ifreqMTU struct {
|
||||
Name [16]byte
|
||||
MTU int32
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
type ifreqQLEN struct {
|
||||
Name [16]byte
|
||||
Value int32
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||
tunDevice, err := wgtun.CreateTUNFromFile(file, mtu)
|
||||
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN device from fd: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &wgTun{
|
||||
tunDevice: tunDevice,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MaxMTU: mtu,
|
||||
DefaultMTU: mtu,
|
||||
l: l,
|
||||
t.Device = "tun0"
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
if os.IsNotExist(err) {
|
||||
err = os.MkdirAll("/dev/net", 0755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||
}
|
||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||
}
|
||||
|
||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Create Linux-specific route manager
|
||||
routeManager := &tun{
|
||||
txQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
copy(req.Name[:], c.GetString("tun.dev", ""))
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.Device = name
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
l: l,
|
||||
}
|
||||
t.routeManager = routeManager
|
||||
|
||||
err = t.reload(c, true)
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -121,105 +149,273 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (rm *tun) Activate(t *wgTun) error {
|
||||
name, err := t.tunDevice.Name()
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device name: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if t.routeManager.useSystemRoutes {
|
||||
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldDefaultMTU := t.DefaultMTU
|
||||
oldMaxMTU := t.MaxMTU
|
||||
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
|
||||
newMaxMTU := newDefaultMTU
|
||||
for i, r := range routes {
|
||||
if r.MTU == 0 {
|
||||
routes[i].MTU = newDefaultMTU
|
||||
}
|
||||
|
||||
if r.MTU > t.MaxMTU {
|
||||
newMaxMTU = r.MTU
|
||||
}
|
||||
}
|
||||
|
||||
t.MaxMTU = newMaxMTU
|
||||
t.DefaultMTU = newDefaultMTU
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
if oldMaxMTU != newMaxMTU {
|
||||
t.setMTU()
|
||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
||||
}
|
||||
|
||||
if oldDefaultMTU != newDefaultMTU {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.setDefaultRoute(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
t.l.Warn(err)
|
||||
} else {
|
||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// This should never be called since addRoutes should log its own errors in a reload condition
|
||||
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
copy(req.Name[:], t.Device)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
var nn int
|
||||
maximum := len(b)
|
||||
|
||||
for {
|
||||
n, err := unix.Write(t.fd, b[nn:maximum])
|
||||
if n > 0 {
|
||||
nn += n
|
||||
}
|
||||
if nn == len(b) {
|
||||
return nn, err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nn, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nn, io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
||||
for i := range al {
|
||||
if al[i].Equal(x) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
|
||||
func (t *tun) addIPs(link netlink.Link) error {
|
||||
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
||||
for i := range t.vpnNetworks {
|
||||
newAddrs[i] = &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
||||
},
|
||||
Label: t.vpnNetworks[i].Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
//add all new addresses
|
||||
for i := range newAddrs {
|
||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
//iterate over remainder, remove whoever shouldn't be there
|
||||
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun address list: %s", err)
|
||||
}
|
||||
|
||||
for i := range al {
|
||||
if hasNetlinkAddr(newAddrs, al[i]) {
|
||||
continue
|
||||
}
|
||||
err = netlink.AddrDel(link, &al[i])
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
||||
} else {
|
||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
devName := t.deviceBytes()
|
||||
|
||||
if t.useSystemRoutes {
|
||||
t.watchRoutes()
|
||||
}
|
||||
|
||||
// Get the netlink device
|
||||
link, err := netlink.LinkByName(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||
}
|
||||
|
||||
rm.deviceIndex = link.Attrs().Index
|
||||
|
||||
// Open socket for ioctl operations
|
||||
s, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
|
||||
unix.SOCK_DGRAM,
|
||||
unix.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rm.ioctlFd = uintptr(s)
|
||||
t.ioctlFd = uintptr(s)
|
||||
|
||||
rm.SetMTU(t, t.MaxMTU)
|
||||
// Set the device name
|
||||
ifrf := ifReq{Name: devName}
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to set tun device name: %s", err)
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(t.Device)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||
}
|
||||
|
||||
t.deviceIndex = link.Attrs().Index
|
||||
|
||||
// Setup our default MTU
|
||||
t.setMTU()
|
||||
|
||||
// Set the transmit queue length
|
||||
devName := deviceBytes(name)
|
||||
ifrq := ifreqQLEN{Name: devName, Value: int32(rm.txQueueLen)}
|
||||
if err = ioctl(t.routeManager.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
|
||||
// Disable IPv6 link-local address generation
|
||||
const modeNone = 1
|
||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||
}
|
||||
|
||||
// Add IP addresses
|
||||
if err = t.routeManager.addIPs(t, link); err != nil {
|
||||
if err = t.addIPs(link); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
if err = netlink.LinkSetUp(link); err != nil {
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
||||
}
|
||||
|
||||
// Set route MTU
|
||||
//set route MTU
|
||||
for i := range t.vpnNetworks {
|
||||
if err = t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i]); err != nil {
|
||||
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
|
||||
return fmt.Errorf("failed to set default route MTU: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the routes
|
||||
if err = t.routeManager.AddRoutes(t, false); err != nil {
|
||||
if err = t.addRoutes(false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to run tun device: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||
return
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(name)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get link for MTU set")
|
||||
return
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetMTU(link, mtu); err != nil {
|
||||
func (t *tun) setMTU() {
|
||||
// Set the MTU on the device
|
||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
||||
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||
dr := &net.IPNet{
|
||||
IP: cidr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.routeManager.deviceIndex,
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
MTU: t.DefaultMTU,
|
||||
AdvMSS: advMSS(Route{}, t.DefaultMTU, t.MaxMTU),
|
||||
AdvMSS: t.advMSS(Route{}),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Src: net.IP(cidr.Addr().AsSlice()),
|
||||
Protocol: unix.RTPROT_KERNEL,
|
||||
@@ -229,7 +425,7 @@ func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||
err := netlink.RouteReplace(&nr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
||||
// Retry twice more
|
||||
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
||||
for i := 0; i < 2; i++ {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = netlink.RouteReplace(&nr)
|
||||
@@ -247,7 +443,8 @@ func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
// Path routes
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
@@ -260,10 +457,10 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.routeManager.deviceIndex,
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
MTU: r.MTU,
|
||||
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
|
||||
AdvMSS: t.advMSS(r),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
}
|
||||
|
||||
@@ -287,7 +484,7 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||
func (t *tun) removeRoutes(routes []Route) {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
@@ -299,10 +496,10 @@ func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.routeManager.deviceIndex,
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
MTU: r.MTU,
|
||||
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
|
||||
AdvMSS: t.advMSS(r),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
}
|
||||
|
||||
@@ -319,105 +516,30 @@ func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||
// For Linux with WireGuard TUN, we can reuse the same device
|
||||
// The vectorized I/O will handle batching
|
||||
return &wgTunReader{
|
||||
parent: t,
|
||||
tunDevice: t.tunDevice,
|
||||
offset: 0,
|
||||
l: t.l,
|
||||
}, nil
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func deviceBytes(name string) [16]byte {
|
||||
var o [16]byte
|
||||
for i, c := range name {
|
||||
if i >= 16 {
|
||||
break
|
||||
}
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
func advMSS(r Route, defaultMTU, maxMTU int) int {
|
||||
func (t *tun) advMSS(r Route) int {
|
||||
mtu := r.MTU
|
||||
if r.MTU == 0 {
|
||||
mtu = defaultMTU
|
||||
mtu = t.DefaultMTU
|
||||
}
|
||||
|
||||
// We only need to set advmss if the route MTU does not match the device MTU
|
||||
if mtu != maxMTU {
|
||||
if mtu != t.MaxMTU {
|
||||
return mtu - 40
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type ifreqQLEN struct {
|
||||
Name [16]byte
|
||||
Value int32
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
||||
for i := range al {
|
||||
if al[i].Equal(x) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rm *tun) addIPs(t *wgTun, link netlink.Link) error {
|
||||
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
||||
for i := range t.vpnNetworks {
|
||||
newAddrs[i] = &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
||||
},
|
||||
Label: t.vpnNetworks[i].Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
// Add all new addresses
|
||||
for i := range newAddrs {
|
||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate over remainder, remove whoever shouldn't be there
|
||||
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun address list: %s", err)
|
||||
}
|
||||
|
||||
for i := range al {
|
||||
if hasNetlinkAddr(newAddrs, al[i]) {
|
||||
continue
|
||||
}
|
||||
err = netlink.AddrDel(link, &al[i])
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
||||
} else {
|
||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// watchRoutes monitors system route changes
|
||||
func (t *wgTun) watchRoutes() {
|
||||
|
||||
func (t *tun) watchRoutes() {
|
||||
rch := make(chan netlink.RouteUpdate)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||
ReceiveBufferSize: t.routeManager.useSystemRoutesBufferSize,
|
||||
ReceiveBufferForceSize: t.routeManager.useSystemRoutesBufferSize != 0,
|
||||
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
||||
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
||||
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
||||
}
|
||||
|
||||
@@ -435,19 +557,87 @@ func (t *wgTun) watchRoutes() {
|
||||
if ok {
|
||||
t.updateRoutes(r)
|
||||
} else {
|
||||
// may be should do something here as
|
||||
// netlink stops sending updates
|
||||
return
|
||||
}
|
||||
case <-doneChan:
|
||||
// netlink.RouteSubscriber will close the rch for us
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
|
||||
gateways := t.getGatewaysFromRoute(&r.Route, t.routeManager.deviceIndex)
|
||||
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
||||
withinNetworks := false
|
||||
for i := range t.vpnNetworks {
|
||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||
withinNetworks = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return withinNetworks
|
||||
}
|
||||
|
||||
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||
|
||||
var gateways routing.Gateways
|
||||
|
||||
link, err := netlink.LinkByName(t.Device)
|
||||
if err != nil {
|
||||
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
|
||||
return gateways
|
||||
}
|
||||
|
||||
// If this route is relevant to our interface and there is a gateway then add it
|
||||
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||
} else {
|
||||
gwAddr = gwAddr.Unmap()
|
||||
|
||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
} else {
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range r.MultiPath {
|
||||
// If this route is relevant to our interface and there is a gateway then add it
|
||||
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
||||
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
||||
} else {
|
||||
gwAddr = gwAddr.Unmap()
|
||||
|
||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
} else {
|
||||
// p.Hops+1 = weight of the route
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
routing.CalculateBucketsForGateways(gateways)
|
||||
return gateways
|
||||
}
|
||||
|
||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
|
||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||
|
||||
if len(gateways) == 0 {
|
||||
// No gateways relevant to our network, no routing changes required.
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||
return
|
||||
}
|
||||
@@ -471,6 +661,7 @@ func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
|
||||
if r.Type == unix.RTM_NEWROUTE {
|
||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
||||
newTree.Insert(dst, gateways)
|
||||
|
||||
} else {
|
||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
||||
newTree.Delete(dst)
|
||||
@@ -478,71 +669,18 @@ func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
|
||||
t.routeTree.Store(newTree)
|
||||
}
|
||||
|
||||
func (t *wgTun) getGatewaysFromRoute(r *netlink.Route, deviceIndex int) routing.Gateways {
|
||||
var gateways routing.Gateways
|
||||
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.Error("Ignoring route update: failed to get device name")
|
||||
return gateways
|
||||
func (t *tun) Close() error {
|
||||
if t.routeChan != nil {
|
||||
close(t.routeChan)
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(name)
|
||||
if err != nil {
|
||||
t.l.WithField("DeviceName", name).Error("Ignoring route update: failed to get link by name")
|
||||
return gateways
|
||||
if t.ReadWriteCloser != nil {
|
||||
_ = t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
// If this route is relevant to our interface and there is a gateway then add it
|
||||
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||
} else {
|
||||
gwAddr = gwAddr.Unmap()
|
||||
|
||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
} else {
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||
}
|
||||
}
|
||||
if t.ioctlFd > 0 {
|
||||
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
||||
}
|
||||
|
||||
for _, p := range r.MultiPath {
|
||||
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
||||
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
||||
} else {
|
||||
gwAddr = gwAddr.Unmap()
|
||||
|
||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
} else {
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
routing.CalculateBucketsForGateways(gateways)
|
||||
return gateways
|
||||
}
|
||||
|
||||
func (t *wgTun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
||||
for i := range t.vpnNetworks {
|
||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ioctl(a1, a2, a3 uintptr) error {
|
||||
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,27 +6,26 @@ package overlay
|
||||
import "testing"
|
||||
|
||||
var runAdvMSSTests = []struct {
|
||||
name string
|
||||
defaultMTU int
|
||||
maxMTU int
|
||||
r Route
|
||||
expected int
|
||||
name string
|
||||
tun *tun
|
||||
r Route
|
||||
expected int
|
||||
}{
|
||||
// Standard case, default MTU is the device max MTU
|
||||
{"default", 1440, 1440, Route{}, 0},
|
||||
{"default-min", 1440, 1440, Route{MTU: 1440}, 0},
|
||||
{"default-low", 1440, 1440, Route{MTU: 1200}, 1160},
|
||||
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
|
||||
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
|
||||
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
|
||||
|
||||
// Case where we have a route MTU set higher than the default
|
||||
{"route", 1440, 8941, Route{}, 1400},
|
||||
{"route-min", 1440, 8941, Route{MTU: 1440}, 1400},
|
||||
{"route-high", 1440, 8941, Route{MTU: 8941}, 0},
|
||||
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
|
||||
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
|
||||
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
|
||||
}
|
||||
|
||||
func TestTunAdvMSS(t *testing.T) {
|
||||
for _, tt := range runAdvMSSTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := advMSS(tt.r, tt.defaultMTU, tt.maxMTU)
|
||||
o := tt.tun.advMSS(tt.r)
|
||||
if o != tt.expected {
|
||||
t.Errorf("got %d, want %d", o, tt.expected)
|
||||
}
|
||||
|
||||
@@ -547,41 +547,3 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ioctl(a1, a2, a3 uintptr) error {
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
bits := prefix.Bits()
|
||||
if prefix.Addr().Is4() {
|
||||
mask := ^uint32(0) << (32 - bits)
|
||||
return netip.AddrFrom4([4]byte{
|
||||
byte(mask >> 24),
|
||||
byte(mask >> 16),
|
||||
byte(mask >> 8),
|
||||
byte(mask),
|
||||
})
|
||||
}
|
||||
var mask [16]byte
|
||||
for i := 0; i < bits/8; i++ {
|
||||
mask[i] = 0xff
|
||||
}
|
||||
if bits%8 != 0 {
|
||||
mask[bits/8] = ^byte(0) << (8 - bits%8)
|
||||
}
|
||||
return netip.AddrFrom16(mask)
|
||||
}
|
||||
|
||||
func selectGateway(prefix netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||
for _, gw := range gateways {
|
||||
if prefix.Addr().Is4() == gw.Addr().Is4() {
|
||||
return gw, nil
|
||||
}
|
||||
}
|
||||
return netip.Prefix{}, fmt.Errorf("no suitable gateway found for prefix %v", prefix)
|
||||
}
|
||||
|
||||
14
overlay/tun_notwin.go
Normal file
14
overlay/tun_notwin.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package overlay
|
||||
|
||||
import "syscall"
|
||||
|
||||
func ioctl(a1, a2, a3 uintptr) error {
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,59 +1,104 @@
|
||||
//go:build openbsd && !e2e_testing
|
||||
// +build openbsd,!e2e_testing
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"os"
|
||||
"regexp"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type tun struct{}
|
||||
const (
|
||||
SIOCAIFADDR_IN6 = 0x8080691a
|
||||
)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported on OpenBSD")
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||
deviceName := c.GetString("tun.dev", "tun")
|
||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime [2]uint32
|
||||
}
|
||||
|
||||
// Create WireGuard TUN device
|
||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||
type ifreq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
data int
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
|
||||
// Get the actual device name
|
||||
actualName, err := tunDevice.Name()
|
||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &wgTun{
|
||||
tunDevice: tunDevice,
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MaxMTU: mtu,
|
||||
DefaultMTU: mtu,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
// Create OpenBSD-specific route manager
|
||||
t.routeManager = &tun{}
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -64,86 +109,221 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
})
|
||||
|
||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (rm *tun) Activate(t *wgTun) error {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device name: %w", err)
|
||||
func (t *tun) Close() error {
|
||||
if t.f != nil {
|
||||
if err := t.f.Close(); err != nil {
|
||||
return fmt.Errorf("error closing tun file: %w", err)
|
||||
}
|
||||
|
||||
// t.f.Close should have handled it for us but let's be extra sure
|
||||
_ = unix.Close(t.fd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.f.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
// Set the MTU
|
||||
rm.SetMTU(t, t.MaxMTU)
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
// Add IP addresses
|
||||
for _, network := range t.vpnNetworks {
|
||||
if err := rm.addIP(t, name, network); err != nil {
|
||||
copy(buf[4:], from)
|
||||
|
||||
n, err := t.f.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
var req ifreqAlias4
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||
}
|
||||
|
||||
err = addRoute(cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
var req ifreqAlias6
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
}
|
||||
req.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
}
|
||||
req.Lifetime[0] = 0xffffffff
|
||||
req.Lifetime[1] = 0xffffffff
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||
}
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err = t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %w", err)
|
||||
}
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
// Set the routes
|
||||
if err := rm.AddRoutes(t, false); err != nil {
|
||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||
return
|
||||
}
|
||||
|
||||
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
|
||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||
}
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||
// On OpenBSD, routes are set via ifconfig and route commands
|
||||
return nil
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device name: %w", err)
|
||||
}
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
// Add route using route command
|
||||
args := []string{"add"}
|
||||
|
||||
if r.Cidr.Addr().Is6() {
|
||||
args = append(args, "-inet6")
|
||||
} else {
|
||||
args = append(args, "-inet")
|
||||
}
|
||||
|
||||
args = append(args, r.Cidr.String(), "-interface", name)
|
||||
|
||||
if r.Metric > 0 {
|
||||
// OpenBSD doesn't support route metrics directly like Linux
|
||||
t.l.WithField("route", r).Warn("Route metrics are not fully supported on OpenBSD")
|
||||
}
|
||||
|
||||
err := runCommandBSD("route", args...)
|
||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
@@ -159,71 +339,131 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get device name for route removal")
|
||||
return
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
args := []string{"delete"}
|
||||
|
||||
if r.Cidr.Addr().Is6() {
|
||||
args = append(args, "-inet6")
|
||||
} else {
|
||||
args = append(args, "-inet")
|
||||
}
|
||||
|
||||
args = append(args, r.Cidr.String(), "-interface", name)
|
||||
|
||||
err := runCommandBSD("route", args...)
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||
// OpenBSD doesn't support multi-queue TUN devices in the same way as Linux
|
||||
// Return a reader that wraps the same device
|
||||
return &wgTunReader{
|
||||
parent: t,
|
||||
tunDevice: t.tunDevice,
|
||||
offset: 0,
|
||||
l: t.l,
|
||||
}, nil
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
||||
addr := network.Addr()
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
if addr.Is4() {
|
||||
// For IPv4: ifconfig tun0 10.0.0.1/24
|
||||
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
|
||||
return fmt.Errorf("failed to add IPv4 address: %w", err)
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
|
||||
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
|
||||
return fmt.Errorf("failed to add IPv6 address: %w", err)
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCommandBSD(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
//go:build !android && !netbsd && !e2e_testing
|
||||
// +build !android,!netbsd,!e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// wgTun wraps a WireGuard TUN device and implements the overlay.Device interface
|
||||
type wgTun struct {
|
||||
tunDevice wgtun.Device
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
routeChan chan struct{}
|
||||
|
||||
// Platform-specific route management
|
||||
routeManager *tun
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
// BatchReader interface for readers that support vectorized I/O
|
||||
type BatchReader interface {
|
||||
BatchRead(buffers [][]byte, sizes []int) (int, error)
|
||||
}
|
||||
|
||||
// BatchWriter interface for writers that support vectorized I/O
|
||||
type BatchWriter interface {
|
||||
BatchWrite(packets [][]byte) (int, error)
|
||||
}
|
||||
|
||||
// wgTunReader wraps a single TUN queue for multi-queue support
|
||||
type wgTunReader struct {
|
||||
parent *wgTun
|
||||
tunDevice wgtun.Device
|
||||
offset int
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func (t *wgTun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *wgTun) Name() string {
|
||||
name, err := t.tunDevice.Name()
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Failed to get TUN device name")
|
||||
return "unknown"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (t *wgTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *wgTun) Activate() error {
|
||||
if t.routeManager == nil {
|
||||
return fmt.Errorf("route manager not initialized")
|
||||
}
|
||||
return t.routeManager.Activate(t)
|
||||
}
|
||||
|
||||
// Read implements single-packet read for backward compatibility
|
||||
func (t *wgTun) Read(b []byte) (int, error) {
|
||||
bufs := [][]byte{b}
|
||||
sizes := []int{0}
|
||||
n, err := t.tunDevice.Read(bufs, sizes, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.ErrNoProgress
|
||||
}
|
||||
return sizes[0], nil
|
||||
}
|
||||
|
||||
// Write implements single-packet write for backward compatibility
|
||||
func (t *wgTun) Write(b []byte) (int, error) {
|
||||
bufs := [][]byte{b}
|
||||
offset := 0
|
||||
|
||||
// WireGuard TUN expects the packet data to start at offset 0
|
||||
n, err := t.tunDevice.Write(bufs, offset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *wgTun) Close() error {
|
||||
if t.routeChan != nil {
|
||||
close(t.routeChan)
|
||||
}
|
||||
|
||||
if t.tunDevice != nil {
|
||||
return t.tunDevice.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *wgTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
// For WireGuard TUN, we need to create separate TUN device instances for multi-queue
|
||||
// The platform-specific implementation will handle this
|
||||
if t.routeManager == nil {
|
||||
return nil, fmt.Errorf("route manager not initialized for multi-queue reader")
|
||||
}
|
||||
|
||||
return t.routeManager.NewMultiQueueReader(t)
|
||||
}
|
||||
|
||||
func (t *wgTun) reload(c *config.C, initial bool) error {
|
||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldDefaultMTU := t.DefaultMTU
|
||||
oldMaxMTU := t.MaxMTU
|
||||
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
|
||||
newMaxMTU := newDefaultMTU
|
||||
for i, r := range routes {
|
||||
if r.MTU == 0 {
|
||||
routes[i].MTU = newDefaultMTU
|
||||
}
|
||||
|
||||
if r.MTU > t.MaxMTU {
|
||||
newMaxMTU = r.MTU
|
||||
}
|
||||
}
|
||||
|
||||
t.MaxMTU = newMaxMTU
|
||||
t.DefaultMTU = newDefaultMTU
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial && t.routeManager != nil {
|
||||
if oldMaxMTU != newMaxMTU {
|
||||
t.routeManager.SetMTU(t, t.MaxMTU)
|
||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
||||
}
|
||||
|
||||
if oldDefaultMTU != newDefaultMTU {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
t.l.Warn(err)
|
||||
} else {
|
||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
t.routeManager.RemoveRoutes(t, findRemovedRoutes(routes, *oldRoutes))
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.routeManager.AddRoutes(t, true)
|
||||
if err != nil {
|
||||
// This should never be called since AddRoutes should log its own errors in a reload condition
|
||||
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchRead reads multiple packets from the TUN device using vectorized I/O
|
||||
// The caller provides buffers and sizes slices, and this function returns the number of packets read.
|
||||
func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
|
||||
return r.tunDevice.Read(buffers, sizes, r.offset)
|
||||
}
|
||||
|
||||
// Read implements io.Reader for wgTunReader (single packet for compatibility)
|
||||
func (r *wgTunReader) Read(b []byte) (int, error) {
|
||||
bufs := [][]byte{b}
|
||||
sizes := []int{0}
|
||||
n, err := r.tunDevice.Read(bufs, sizes, r.offset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.ErrNoProgress
|
||||
}
|
||||
return sizes[0], nil
|
||||
}
|
||||
|
||||
// Write implements io.Writer for wgTunReader
|
||||
func (r *wgTunReader) Write(b []byte) (int, error) {
|
||||
bufs := [][]byte{b}
|
||||
n, err := r.tunDevice.Write(bufs, r.offset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// BatchWrite writes multiple packets to the TUN device using vectorized I/O
|
||||
func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) {
|
||||
return r.tunDevice.Write(packets, r.offset)
|
||||
}
|
||||
|
||||
func (r *wgTunReader) Close() error {
|
||||
if r.tunDevice != nil {
|
||||
return r.tunDevice.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,77 +1,84 @@
|
||||
//go:build windows && !e2e_testing
|
||||
// +build windows,!e2e_testing
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/wintun"
|
||||
"golang.org/x/sys/windows"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
||||
|
||||
type tun struct {
|
||||
luid winipcfg.LUID
|
||||
type winTun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||
deviceName := c.GetString("tun.dev", "Nebula")
|
||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||
|
||||
// Create WireGuard TUN device
|
||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
err := checkWinTunExists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||
}
|
||||
|
||||
// Get the actual device name
|
||||
actualName, err := tunDevice.Name()
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
guid, err := generateGUIDByDeviceName(deviceName)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||
}
|
||||
|
||||
t := &wgTun{
|
||||
tunDevice: tunDevice,
|
||||
t := &winTun{
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MaxMTU: mtu,
|
||||
DefaultMTU: mtu,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
// Create Windows-specific route manager
|
||||
rm := &tun{}
|
||||
|
||||
// Get LUID from the TUN device
|
||||
// The WireGuard TUN device on Windows should provide a LUID() method
|
||||
if nativeTun, ok := tunDevice.(interface{ LUID() uint64 }); ok {
|
||||
rm.luid = winipcfg.LUID(nativeTun.LUID())
|
||||
} else {
|
||||
tunDevice.Close()
|
||||
return nil, fmt.Errorf("failed to get LUID from TUN device")
|
||||
}
|
||||
t.routeManager = rm
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
tunDevice.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tunDevice wintun.Device
|
||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||
if err != nil {
|
||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||
// Trying a second time resolves the issue.
|
||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
||||
}
|
||||
}
|
||||
t.tun = tunDevice.(*wintun.NativeTun)
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
@@ -79,140 +86,206 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
})
|
||||
|
||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (rm *tun) Activate(t *wgTun) error {
|
||||
// Set MTU
|
||||
err := rm.setMTU(t, t.MaxMTU)
|
||||
func (t *winTun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set MTU: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Add IP addresses
|
||||
for _, network := range t.vpnNetworks {
|
||||
if err := rm.addIP(t, network); err != nil {
|
||||
return err
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
// Add routes
|
||||
if err := rm.AddRoutes(t, false); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTun) Activate() error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
|
||||
err := luid.SetIPAddresses(t.vpnNetworks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set address: %w", err)
|
||||
}
|
||||
|
||||
err = t.addRoutes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||
if err := rm.setMTU(t, mtu); err != nil {
|
||||
t.l.WithError(err).Error("Failed to set MTU")
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *tun) setMTU(t *wgTun, mtu int) error {
|
||||
// Set MTU using winipcfg
|
||||
// Note: MTU setting on Windows TUN devices may be handled by the driver
|
||||
// For now, we'll skip explicit MTU setting as the WireGuard TUN handles it
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||
// On Windows, routes are managed differently
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||
func (t *winTun) addRoutes(logErrors bool) error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
routes := *t.Routes.Load()
|
||||
foundDefault4 := false
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
if r.MTU > 0 {
|
||||
// Windows route MTU is not directly supported
|
||||
t.l.WithField("route", r).Debug("Route MTU is not supported on Windows")
|
||||
}
|
||||
|
||||
// Use winipcfg to add the route
|
||||
// The rm.luid should have the AddRoute method from winipcfg
|
||||
if len(r.Via) == 0 {
|
||||
t.l.WithField("route", r).Warn("Route has no via address, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
err := rm.luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
||||
// Add our unsafe route
|
||||
// Windows does not support multipath routes natively, so we install only a single route.
|
||||
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
|
||||
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
|
||||
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
continue
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
|
||||
if !foundDefault4 {
|
||||
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
|
||||
foundDefault4 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ipif, err := luid.IPInterface(windows.AF_INET)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get ip interface: %w", err)
|
||||
}
|
||||
|
||||
ipif.NLMTU = uint32(t.MTU)
|
||||
if foundDefault4 {
|
||||
ipif.UseAutomaticMetric = false
|
||||
ipif.Metric = 0
|
||||
}
|
||||
|
||||
if err := ipif.Set(); err != nil {
|
||||
return fmt.Errorf("failed to set ip interface: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||
func (t *winTun) removeRoutes(routes []Route) error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(r.Via) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err := rm.luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||
// See comment on luid.AddRoute
|
||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||
// Windows doesn't support multi-queue TUN devices
|
||||
// Return a reader that wraps the same device
|
||||
return &wgTunReader{
|
||||
parent: t,
|
||||
tunDevice: t.tunDevice,
|
||||
offset: 0,
|
||||
l: t.l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rm *tun) addIP(t *wgTun, network netip.Prefix) error {
|
||||
// Add IP address using winipcfg
|
||||
// SetIPAddresses expects a slice of prefixes
|
||||
err := rm.luid.SetIPAddresses([]netip.Prefix{network})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IP address %s: %w", network, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateGUIDByDeviceName generates a GUID based on the device name
|
||||
func generateGUIDByDeviceName(deviceName string) (*windows.GUID, error) {
|
||||
// Hash the device name to create a deterministic GUID
|
||||
h := crypto.SHA256.New()
|
||||
h.Write([]byte(tunGUIDLabel))
|
||||
h.Write([]byte(deviceName))
|
||||
sum := h.Sum(nil)
|
||||
|
||||
guid := &windows.GUID{
|
||||
Data1: binary.LittleEndian.Uint32(sum[0:4]),
|
||||
Data2: binary.LittleEndian.Uint16(sum[4:6]),
|
||||
Data3: binary.LittleEndian.Uint16(sum[6:8]),
|
||||
}
|
||||
copy(guid.Data4[:], sum[8:16])
|
||||
|
||||
return guid, nil
|
||||
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *winTun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *winTun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *winTun) Read(b []byte) (int, error) {
|
||||
return t.tun.Read(b, 0)
|
||||
}
|
||||
|
||||
func (t *winTun) Write(b []byte) (int, error) {
|
||||
return t.tun.Write(b, 0)
|
||||
}
|
||||
|
||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||
}
|
||||
|
||||
func (t *winTun) Close() error {
|
||||
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
||||
// so to be certain, just remove everything before destroying.
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
_ = luid.FlushRoutes(windows.AF_INET)
|
||||
_ = luid.FlushIPAddresses(windows.AF_INET)
|
||||
|
||||
_ = luid.FlushRoutes(windows.AF_INET6)
|
||||
_ = luid.FlushIPAddresses(windows.AF_INET6)
|
||||
|
||||
_ = luid.FlushDNS(windows.AF_INET)
|
||||
_ = luid.FlushDNS(windows.AF_INET6)
|
||||
|
||||
return t.tun.Close()
|
||||
}
|
||||
|
||||
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
||||
// GUID is 128 bit
|
||||
hash := crypto.MD5.New()
|
||||
|
||||
_, err := hash.Write([]byte(tunGUIDLabel))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = hash.Write([]byte(name))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sum := hash.Sum(nil)
|
||||
|
||||
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
||||
}
|
||||
|
||||
func checkWinTunExists() error {
|
||||
myPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
arch := runtime.GOARCH
|
||||
switch arch {
|
||||
case "386":
|
||||
//NOTE: wintun bundles 386 as x86
|
||||
arch = "x86"
|
||||
}
|
||||
|
||||
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
|
||||
return err
|
||||
}
|
||||
|
||||
95
pki.go
95
pki.go
@@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
currentState := p.cs.Load()
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||
}
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new v1 cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v1Cert != nil {
|
||||
//TODO: CERT-V2 we should be able to tear this down
|
||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
if newState.v2Cert != nil {
|
||||
if currentState.v2Cert == nil {
|
||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||
}
|
||||
//adding certs is fine, actually
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||
if newState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||
}
|
||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
@@ -516,9 +523,13 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||
}
|
||||
|
||||
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
||||
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
||||
caPool.BlocklistFingerprint(fp)
|
||||
bl := c.GetStringSlice("pki.blocklist", []string{})
|
||||
if len(bl) > 0 {
|
||||
for _, fp := range bl {
|
||||
caPool.BlocklistFingerprint(fp)
|
||||
}
|
||||
|
||||
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
||||
}
|
||||
|
||||
return caPool, nil
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type m = map[string]any
|
||||
|
||||
Reference in New Issue
Block a user