mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
13 Commits
e2e-bench-
...
botched-pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7423d39f9 | ||
|
|
befba57366 | ||
|
|
2d128a3254 | ||
|
|
c8980d34cf | ||
|
|
98f264cf14 | ||
|
|
aa44f4c7c9 | ||
|
|
419157c407 | ||
|
|
0864852d33 | ||
|
|
2b5aec9a18 | ||
|
|
f0665bee20 | ||
|
|
11da0baab1 | ||
|
|
608904b9dd | ||
|
|
fd1c52127f |
2
.github/workflows/gofmt.yml
vendored
2
.github/workflows/gofmt.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
|
|||||||
20
.github/workflows/release.yml
vendored
20
.github/workflows/release.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
|||||||
name: Build Linux/BSD All
|
name: Build Linux/BSD All
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -24,7 +24,7 @@ jobs:
|
|||||||
mv build/*.tar.gz release
|
mv build/*.tar.gz release
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-latest
|
name: linux-latest
|
||||||
path: release
|
path: release
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
name: Build Windows
|
name: Build Windows
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
mv dist\windows\wintun build\dist\windows\
|
mv dist\windows\wintun build\dist\windows\
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-latest
|
name: windows-latest
|
||||||
path: build
|
path: build
|
||||||
@@ -66,7 +66,7 @@ jobs:
|
|||||||
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -104,7 +104,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: darwin-latest
|
name: darwin-latest
|
||||||
path: ./release/*
|
path: ./release/*
|
||||||
@@ -124,11 +124,11 @@ jobs:
|
|||||||
# be overwritten
|
# be overwritten
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Download artifacts
|
- name: Download artifacts
|
||||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||||
uses: actions/download-artifact@v6
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-latest
|
name: linux-latest
|
||||||
path: artifacts
|
path: artifacts
|
||||||
@@ -160,10 +160,10 @@ jobs:
|
|||||||
needs: [build-linux, build-darwin, build-windows]
|
needs: [build-linux, build-darwin, build-windows]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Download artifacts
|
- name: Download artifacts
|
||||||
uses: actions/download-artifact@v6
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
path: artifacts
|
path: artifacts
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/smoke-extra.yml
vendored
2
.github/workflows/smoke-extra.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/smoke.yml
vendored
2
.github/workflows/smoke.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
|
|||||||
16
.github/workflows/test.yml
vendored
16
.github/workflows/test.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -32,7 +32,7 @@ jobs:
|
|||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v9
|
uses: golangci/golangci-lint-action@v8
|
||||||
with:
|
with:
|
||||||
version: v2.5
|
version: v2.5
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Build test mobile
|
- name: Build test mobile
|
||||||
run: make build-test-mobile
|
run: make build-test-mobile
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v5
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: e2e packet flow linux-latest
|
name: e2e packet flow linux-latest
|
||||||
path: e2e/mermaid/linux-latest
|
path: e2e/mermaid/linux-latest
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -77,7 +77,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
os: [windows-latest, macos-latest]
|
os: [windows-latest, macos-latest]
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -115,7 +115,7 @@ jobs:
|
|||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v9
|
uses: golangci/golangci-lint-action@v8
|
||||||
with:
|
with:
|
||||||
version: v2.5
|
version: v2.5
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ jobs:
|
|||||||
- name: End 2 end
|
- name: End 2 end
|
||||||
run: make e2evv
|
run: make e2evv
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v5
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: e2e packet flow ${{ matrix.os }}
|
name: e2e packet flow ${{ matrix.os }}
|
||||||
path: e2e/mermaid/${{ matrix.os }}
|
path: e2e/mermaid/${{ matrix.os }}
|
||||||
|
|||||||
164
batch_pipeline.go
Normal file
164
batch_pipeline.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/overlay"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// batchPipelines tracks whether the inside device can operate on packet batches
|
||||||
|
// and, if so, holds the shared packet pool sized for the virtio headroom and
|
||||||
|
// payload limits advertised by the device. It also owns the fan-in/fan-out
|
||||||
|
// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
|
||||||
|
type batchPipelines struct {
|
||||||
|
enabled bool
|
||||||
|
inside overlay.BatchCapableDevice
|
||||||
|
headroom int
|
||||||
|
payloadCap int
|
||||||
|
pool *overlay.PacketPool
|
||||||
|
batchSize int
|
||||||
|
routines int
|
||||||
|
rxQueues []chan *overlay.Packet
|
||||||
|
txQueues []chan queuedDatagram
|
||||||
|
tunQueues []chan *overlay.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
type queuedDatagram struct {
|
||||||
|
packet *overlay.Packet
|
||||||
|
addr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
|
||||||
|
if device == nil || routines <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bcap, ok := device.(overlay.BatchCapableDevice)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
headroom := bcap.BatchHeadroom()
|
||||||
|
payload := bcap.BatchPayloadCap()
|
||||||
|
if maxSegments < 1 {
|
||||||
|
maxSegments = 1
|
||||||
|
}
|
||||||
|
requiredPayload := udp.MTU * maxSegments
|
||||||
|
if payload < requiredPayload {
|
||||||
|
payload = requiredPayload
|
||||||
|
}
|
||||||
|
batchSize := bcap.BatchSize()
|
||||||
|
if headroom <= 0 || payload <= 0 || batchSize <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bp.enabled = true
|
||||||
|
bp.inside = bcap
|
||||||
|
bp.headroom = headroom
|
||||||
|
bp.payloadCap = payload
|
||||||
|
bp.batchSize = batchSize
|
||||||
|
bp.routines = routines
|
||||||
|
bp.pool = overlay.NewPacketPool(headroom, payload)
|
||||||
|
queueCap := batchSize * defaultBatchQueueDepthFactor
|
||||||
|
if queueDepth > 0 {
|
||||||
|
queueCap = queueDepth
|
||||||
|
}
|
||||||
|
if queueCap < batchSize {
|
||||||
|
queueCap = batchSize
|
||||||
|
}
|
||||||
|
bp.rxQueues = make([]chan *overlay.Packet, routines)
|
||||||
|
bp.txQueues = make([]chan queuedDatagram, routines)
|
||||||
|
bp.tunQueues = make([]chan *overlay.Packet, routines)
|
||||||
|
for i := 0; i < routines; i++ {
|
||||||
|
bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
|
||||||
|
bp.txQueues[i] = make(chan queuedDatagram, queueCap)
|
||||||
|
bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) Pool() *overlay.PacketPool {
|
||||||
|
if bp == nil || !bp.enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bp.pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) Enabled() bool {
|
||||||
|
return bp != nil && bp.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) batchSizeHint() int {
|
||||||
|
if bp == nil || bp.batchSize <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return bp.batchSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
|
||||||
|
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bp.rxQueues[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
|
||||||
|
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bp.txQueues[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
|
||||||
|
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bp.tunQueues[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) txQueueLen(i int) int {
|
||||||
|
q := bp.txQueue(i)
|
||||||
|
if q == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return len(q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) tunQueueLen(i int) int {
|
||||||
|
q := bp.tunQueue(i)
|
||||||
|
if q == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return len(q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
|
||||||
|
q := bp.rxQueue(i)
|
||||||
|
if q == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
q <- pkt
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
|
||||||
|
q := bp.txQueue(i)
|
||||||
|
if q == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
q <- queuedDatagram{packet: pkt, addr: addr}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
|
||||||
|
q := bp.tunQueue(i)
|
||||||
|
if q == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
q <- pkt
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *batchPipelines) newPacket() *overlay.Packet {
|
||||||
|
if bp == nil || !bp.enabled || bp.pool == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bp.pool.Get()
|
||||||
|
}
|
||||||
109
bits.go
109
bits.go
@@ -9,13 +9,14 @@ type Bits struct {
|
|||||||
length uint64
|
length uint64
|
||||||
current uint64
|
current uint64
|
||||||
bits []bool
|
bits []bool
|
||||||
|
firstSeen bool
|
||||||
lostCounter metrics.Counter
|
lostCounter metrics.Counter
|
||||||
dupeCounter metrics.Counter
|
dupeCounter metrics.Counter
|
||||||
outOfWindowCounter metrics.Counter
|
outOfWindowCounter metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBits(bits uint64) *Bits {
|
func NewBits(bits uint64) *Bits {
|
||||||
b := &Bits{
|
return &Bits{
|
||||||
length: bits,
|
length: bits,
|
||||||
bits: make([]bool, bits, bits),
|
bits: make([]bool, bits, bits),
|
||||||
current: 0,
|
current: 0,
|
||||||
@@ -23,37 +24,34 @@ func NewBits(bits uint64) *Bits {
|
|||||||
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
||||||
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
|
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
// There is no counter value 0, mark it to avoid counting a lost packet later.
|
|
||||||
b.bits[0] = true
|
|
||||||
b.current = 0
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
||||||
// If i is the next number, return true.
|
// If i is the next number, return true.
|
||||||
if i > b.current {
|
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If i is within the window, check if it's been set already.
|
// If i is within the window, check if it's been set already. The first window will fail this check
|
||||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
if i > b.current-b.length {
|
||||||
|
return !b.bits[i%b.length]
|
||||||
|
}
|
||||||
|
|
||||||
|
// If i is within the first window
|
||||||
|
if i < b.length {
|
||||||
return !b.bits[i%b.length]
|
return !b.bits[i%b.length]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
if l.Level >= logrus.DebugLevel {
|
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true and update current.
|
// If i is the next number, return true and update current.
|
||||||
if i == b.current+1 {
|
if i == b.current+1 {
|
||||||
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
// Report missed packets, we can only understand what was missed after the first window has been gone through
|
||||||
// The very first window can only be tracked as lost once we are on the 2nd window or greater
|
if i > b.length && b.bits[i%b.length] == false {
|
||||||
if b.bits[i%b.length] == false && i > b.length {
|
|
||||||
b.lostCounter.Inc(1)
|
b.lostCounter.Inc(1)
|
||||||
}
|
}
|
||||||
b.bits[i%b.length] = true
|
b.bits[i%b.length] = true
|
||||||
@@ -61,32 +59,61 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If i is a jump, adjust the window, record lost, update current, and return true
|
// If i packet is greater than current but less than the maximum length of our bitmap,
|
||||||
if i > b.current {
|
// flip everything in between to false and move ahead.
|
||||||
lost := int64(0)
|
if i > b.current && i < b.current+b.length {
|
||||||
// Zero out the bits between the current and the new counter value, limited by the window size,
|
// In between current and i need to be zero'd to allow those packets to come in later
|
||||||
// since the window is shifting
|
for n := b.current + 1; n < i; n++ {
|
||||||
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
|
|
||||||
if b.bits[n%b.length] == false && n > b.length {
|
|
||||||
lost++
|
|
||||||
}
|
|
||||||
b.bits[n%b.length] = false
|
b.bits[n%b.length] = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only record any skipped packets as a result of the window moving further than the window length
|
b.bits[i%b.length] = true
|
||||||
// Any loss within the new window will be accounted for in future calls
|
b.current = i
|
||||||
lost += max(0, int64(i-b.current-b.length))
|
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If i is greater than the delta between current and the total length of our bitmap,
|
||||||
|
// just flip everything in the map and move ahead.
|
||||||
|
if i >= b.current+b.length {
|
||||||
|
// The current window loss will be accounted for later, only record the jump as loss up until then
|
||||||
|
lost := maxInt64(0, int64(i-b.current-b.length))
|
||||||
|
//TODO: explain this
|
||||||
|
if b.current == 0 {
|
||||||
|
lost++
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := range b.bits {
|
||||||
|
// Don't want to count the first window as a loss
|
||||||
|
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
|
||||||
|
//if b.bits[n] == false {
|
||||||
|
// lost++
|
||||||
|
//}
|
||||||
|
b.bits[n] = false
|
||||||
|
}
|
||||||
|
|
||||||
b.lostCounter.Inc(lost)
|
b.lostCounter.Inc(lost)
|
||||||
|
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
|
||||||
|
Debug("Receive window")
|
||||||
|
}
|
||||||
b.bits[i%b.length] = true
|
b.bits[i%b.length] = true
|
||||||
b.current = i
|
b.current = i
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If i is within the current window but below the current counter,
|
// Allow for the 0 packet to come in within the first window
|
||||||
// Check to see if it's a duplicate
|
if i == 0 && b.firstSeen == false && b.current < b.length {
|
||||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
b.firstSeen = true
|
||||||
if b.current == i || b.bits[i%b.length] == true {
|
b.bits[i%b.length] = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If i is within the window of current minus length (the total pat window size),
|
||||||
|
// allow it and flip to true but to NOT change current. We also have to account for the first window
|
||||||
|
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
|
||||||
|
if b.current == i {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
||||||
Debug("Receive window")
|
Debug("Receive window")
|
||||||
@@ -95,8 +122,18 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if b.bits[i%b.length] == true {
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
|
||||||
|
Debug("Receive window")
|
||||||
|
}
|
||||||
|
b.dupeCounter.Inc(1)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
b.bits[i%b.length] = true
|
b.bits[i%b.length] = true
|
||||||
return true
|
return true
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// In all other cases, fail and don't change current.
|
// In all other cases, fail and don't change current.
|
||||||
@@ -110,3 +147,11 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func maxInt64(a, b int64) int64 {
|
||||||
|
if a > b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|||||||
109
bits_test.go
109
bits_test.go
@@ -15,41 +15,48 @@ func TestBits(t *testing.T) {
|
|||||||
assert.Len(t, b.bits, 10)
|
assert.Len(t, b.bits, 10)
|
||||||
|
|
||||||
// This is initialized to zero - receive one. This should work.
|
// This is initialized to zero - receive one. This should work.
|
||||||
|
|
||||||
assert.True(t, b.Check(l, 1))
|
assert.True(t, b.Check(l, 1))
|
||||||
assert.True(t, b.Update(l, 1))
|
u := b.Update(l, 1)
|
||||||
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 1, b.current)
|
assert.EqualValues(t, 1, b.current)
|
||||||
g := []bool{true, true, false, false, false, false, false, false, false, false}
|
g := []bool{false, true, false, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Receive two
|
// Receive two
|
||||||
assert.True(t, b.Check(l, 2))
|
assert.True(t, b.Check(l, 2))
|
||||||
assert.True(t, b.Update(l, 2))
|
u = b.Update(l, 2)
|
||||||
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
g = []bool{true, true, true, false, false, false, false, false, false, false}
|
g = []bool{false, true, true, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Receive two again - it will fail
|
// Receive two again - it will fail
|
||||||
assert.False(t, b.Check(l, 2))
|
assert.False(t, b.Check(l, 2))
|
||||||
assert.False(t, b.Update(l, 2))
|
u = b.Update(l, 2)
|
||||||
|
assert.False(t, u)
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
|
|
||||||
// Jump ahead to 15, which should clear everything and set the 6th element
|
// Jump ahead to 15, which should clear everything and set the 6th element
|
||||||
assert.True(t, b.Check(l, 15))
|
assert.True(t, b.Check(l, 15))
|
||||||
assert.True(t, b.Update(l, 15))
|
u = b.Update(l, 15)
|
||||||
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Mark 14, which is allowed because it is in the window
|
// Mark 14, which is allowed because it is in the window
|
||||||
assert.True(t, b.Check(l, 14))
|
assert.True(t, b.Check(l, 14))
|
||||||
assert.True(t, b.Update(l, 14))
|
u = b.Update(l, 14)
|
||||||
|
assert.True(t, u)
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Mark 5, which is not allowed because it is not in the window
|
// Mark 5, which is not allowed because it is not in the window
|
||||||
assert.False(t, b.Check(l, 5))
|
assert.False(t, b.Check(l, 5))
|
||||||
assert.False(t, b.Update(l, 5))
|
u = b.Update(l, 5)
|
||||||
|
assert.False(t, u)
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
@@ -62,29 +69,10 @@ func TestBits(t *testing.T) {
|
|||||||
|
|
||||||
// Walk through a few windows in order
|
// Walk through a few windows in order
|
||||||
b = NewBits(10)
|
b = NewBits(10)
|
||||||
for i := uint64(1); i <= 100; i++ {
|
for i := uint64(0); i <= 100; i++ {
|
||||||
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
||||||
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.False(t, b.Check(l, 1), "Out of window check")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBitsLargeJumps(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
b := NewBits(10)
|
|
||||||
b.lostCounter.Clear()
|
|
||||||
|
|
||||||
b = NewBits(10)
|
|
||||||
b.lostCounter.Clear()
|
|
||||||
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
|
|
||||||
assert.Equal(t, int64(45), b.lostCounter.Count())
|
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
|
|
||||||
assert.Equal(t, int64(89), b.lostCounter.Count())
|
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
|
|
||||||
assert.Equal(t, int64(188), b.lostCounter.Count())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsDupeCounter(t *testing.T) {
|
func TestBitsDupeCounter(t *testing.T) {
|
||||||
@@ -136,7 +124,8 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
|||||||
assert.False(t, b.Update(l, 0))
|
assert.False(t, b.Update(l, 0))
|
||||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
//tODO: make sure lostcounter doesn't increase in orderly increment
|
||||||
|
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
@@ -148,6 +137,8 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
|
//assert.True(t, b.Update(0))
|
||||||
|
assert.True(t, b.Update(l, 0))
|
||||||
assert.True(t, b.Update(l, 20))
|
assert.True(t, b.Update(l, 20))
|
||||||
assert.True(t, b.Update(l, 21))
|
assert.True(t, b.Update(l, 21))
|
||||||
assert.True(t, b.Update(l, 22))
|
assert.True(t, b.Update(l, 22))
|
||||||
@@ -158,7 +149,7 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
assert.True(t, b.Update(l, 27))
|
assert.True(t, b.Update(l, 27))
|
||||||
assert.True(t, b.Update(l, 28))
|
assert.True(t, b.Update(l, 28))
|
||||||
assert.True(t, b.Update(l, 29))
|
assert.True(t, b.Update(l, 29))
|
||||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
@@ -167,6 +158,8 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
|
assert.True(t, b.Update(l, 0))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 9))
|
assert.True(t, b.Update(l, 9))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
// 10 will set 0 index, 0 was already set, no lost packets
|
// 10 will set 0 index, 0 was already set, no lost packets
|
||||||
@@ -221,62 +214,6 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsLostCounterIssue1(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
b := NewBits(10)
|
|
||||||
b.lostCounter.Clear()
|
|
||||||
b.dupeCounter.Clear()
|
|
||||||
b.outOfWindowCounter.Clear()
|
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 4))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 1))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 9))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 2))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 3))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 5))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 6))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 7))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
// assert.True(t, b.Update(l, 8))
|
|
||||||
assert.True(t, b.Update(l, 10))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 11))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 14))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
|
|
||||||
assert.True(t, b.Update(l, 19))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 12))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 13))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 15))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 16))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 17))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 18))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 20))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 21))
|
|
||||||
|
|
||||||
// We missed packet 8 above
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBits(b *testing.B) {
|
func BenchmarkBits(b *testing.B) {
|
||||||
z := NewBits(10)
|
z := NewBits(10)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
|
|||||||
97
cert/pem.go
97
cert/pem.go
@@ -1,8 +1,10 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/hex"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
@@ -138,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backward compatibility functions for older API
|
||||||
|
func MarshalX25519PublicKey(b []byte) []byte {
|
||||||
|
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalX25519PrivateKey(b []byte) []byte {
|
||||||
|
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPublicKey(curve Curve, b []byte) []byte {
|
||||||
|
return MarshalPublicKeyToPEM(curve, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPrivateKey(curve Curve, b []byte) []byte {
|
||||||
|
return MarshalPrivateKeyToPEM(curve, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NebulaCertificate is a compatibility wrapper for the old API
|
||||||
|
type NebulaCertificate struct {
|
||||||
|
Details NebulaCertificateDetails
|
||||||
|
Signature []byte
|
||||||
|
cert Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// NebulaCertificateDetails is a compatibility wrapper for certificate details
|
||||||
|
type NebulaCertificateDetails struct {
|
||||||
|
Name string
|
||||||
|
NotBefore time.Time
|
||||||
|
NotAfter time.Time
|
||||||
|
PublicKey []byte
|
||||||
|
IsCA bool
|
||||||
|
Issuer []byte
|
||||||
|
Curve Curve
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
|
||||||
|
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
|
||||||
|
c, rest, err := UnmarshalCertificateFromPEM(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rest, err
|
||||||
|
}
|
||||||
|
|
||||||
|
issuerBytes, err := func() ([]byte, error) {
|
||||||
|
issuer := c.Issuer()
|
||||||
|
if issuer == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
decoded, err := hex.DecodeString(issuer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return nil, rest, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := c.PublicKey()
|
||||||
|
if pubKey != nil {
|
||||||
|
pubKey = append([]byte(nil), pubKey...)
|
||||||
|
}
|
||||||
|
|
||||||
|
sig := c.Signature()
|
||||||
|
if sig != nil {
|
||||||
|
sig = append([]byte(nil), sig...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &NebulaCertificate{
|
||||||
|
Details: NebulaCertificateDetails{
|
||||||
|
Name: c.Name(),
|
||||||
|
NotBefore: c.NotBefore(),
|
||||||
|
NotAfter: c.NotAfter(),
|
||||||
|
PublicKey: pubKey,
|
||||||
|
IsCA: c.IsCA(),
|
||||||
|
Issuer: issuerBytes,
|
||||||
|
Curve: c.Curve(),
|
||||||
|
},
|
||||||
|
Signature: sig,
|
||||||
|
cert: c,
|
||||||
|
}, rest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssuerString returns the issuer in hex format for compatibility
|
||||||
|
func (n *NebulaCertificate) IssuerString() string {
|
||||||
|
if n.Details.Issuer == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(n.Details.Issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Certificate returns the underlying certificate (read-only)
|
||||||
|
func (n *NebulaCertificate) Certificate() Certificate {
|
||||||
|
return n.cert
|
||||||
|
}
|
||||||
|
|
||||||
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
||||||
// consumed data or an error on failure
|
// consumed data or an error on failure
|
||||||
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||||
|
|||||||
@@ -173,26 +173,23 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||||||
|
|
||||||
var passphrase []byte
|
var passphrase []byte
|
||||||
if !isP11 && *cf.encryption {
|
if !isP11 && *cf.encryption {
|
||||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
for i := 0; i < 5; i++ {
|
||||||
|
out.Write([]byte("Enter passphrase: "))
|
||||||
|
passphrase, err = pr.ReadPassword()
|
||||||
|
|
||||||
|
if err == ErrNoTerminal {
|
||||||
|
return fmt.Errorf("out-key must be encrypted interactively")
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("error reading passphrase: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(passphrase) > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(passphrase) == 0 {
|
if len(passphrase) == 0 {
|
||||||
for i := 0; i < 5; i++ {
|
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||||
out.Write([]byte("Enter passphrase: "))
|
|
||||||
passphrase, err = pr.ReadPassword()
|
|
||||||
|
|
||||||
if err == ErrNoTerminal {
|
|
||||||
return fmt.Errorf("out-key must be encrypted interactively")
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("error reading passphrase: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(passphrase) > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(passphrase) == 0 {
|
|
||||||
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -171,17 +171,6 @@ func Test_ca(t *testing.T) {
|
|||||||
assert.Equal(t, pwPromptOb, ob.String())
|
assert.Equal(t, pwPromptOb, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// test encrypted key with passphrase environment variable
|
|
||||||
os.Remove(keyF.Name())
|
|
||||||
os.Remove(crtF.Name())
|
|
||||||
ob.Reset()
|
|
||||||
eb.Reset()
|
|
||||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
|
||||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
|
||||||
require.NoError(t, ca(args, ob, eb, testpw))
|
|
||||||
assert.Empty(t, eb.String())
|
|
||||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
|
||||||
|
|
||||||
// read encrypted key file and verify default params
|
// read encrypted key file and verify default params
|
||||||
rb, _ = os.ReadFile(keyF.Name())
|
rb, _ = os.ReadFile(keyF.Name())
|
||||||
k, _ := pem.Decode(rb)
|
k, _ := pem.Decode(rb)
|
||||||
|
|||||||
@@ -5,28 +5,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// A version string that can be set with
|
|
||||||
//
|
|
||||||
// -ldflags "-X main.Build=SOMEVERSION"
|
|
||||||
//
|
|
||||||
// at compile-time.
|
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
func init() {
|
|
||||||
if Build == "" {
|
|
||||||
info, ok := debug.ReadBuildInfo()
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type helpError struct {
|
type helpError struct {
|
||||||
s string
|
s string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ type signFlags struct {
|
|||||||
func newSignFlags() *signFlags {
|
func newSignFlags() *signFlags {
|
||||||
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
||||||
sf.set.Usage = func() {}
|
sf.set.Usage = func() {}
|
||||||
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
|
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
|
||||||
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
||||||
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
||||||
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
||||||
@@ -116,28 +116,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
// naively attempt to decode the private key as though it is not encrypted
|
// naively attempt to decode the private key as though it is not encrypted
|
||||||
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
||||||
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
||||||
|
// ask for a passphrase until we get one
|
||||||
var passphrase []byte
|
var passphrase []byte
|
||||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
for i := 0; i < 5; i++ {
|
||||||
if len(passphrase) == 0 {
|
out.Write([]byte("Enter passphrase: "))
|
||||||
// ask for a passphrase until we get one
|
passphrase, err = pr.ReadPassword()
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
out.Write([]byte("Enter passphrase: "))
|
|
||||||
passphrase, err = pr.ReadPassword()
|
|
||||||
|
|
||||||
if errors.Is(err, ErrNoTerminal) {
|
if errors.Is(err, ErrNoTerminal) {
|
||||||
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return fmt.Errorf("error reading password: %s", err)
|
return fmt.Errorf("error reading password: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
if len(passphrase) > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if len(passphrase) == 0 {
|
|
||||||
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
|
if len(passphrase) > 0 {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(passphrase) == 0 {
|
||||||
|
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
|
||||||
|
}
|
||||||
|
|
||||||
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
|
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
|
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
|
||||||
@@ -167,10 +165,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
return fmt.Errorf("ca certificate is expired")
|
return fmt.Errorf("ca certificate is expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
if version == 0 {
|
|
||||||
version = caCert.Version()
|
|
||||||
}
|
|
||||||
|
|
||||||
// if no duration is given, expire one second before the root expires
|
// if no duration is given, expire one second before the root expires
|
||||||
if *sf.duration <= 0 {
|
if *sf.duration <= 0 {
|
||||||
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
||||||
@@ -283,19 +277,21 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
notBefore := time.Now()
|
notBefore := time.Now()
|
||||||
notAfter := notBefore.Add(*sf.duration)
|
notAfter := notBefore.Add(*sf.duration)
|
||||||
|
|
||||||
switch version {
|
if version == 0 || version == cert.Version1 {
|
||||||
case cert.Version1:
|
// Make sure we at least have an ip
|
||||||
// Make sure we have only one ipv4 address
|
|
||||||
if len(v4Networks) != 1 {
|
if len(v4Networks) != 1 {
|
||||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(v6Networks) > 0 {
|
if version == cert.Version1 {
|
||||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
|
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
|
||||||
}
|
if len(v6Networks) > 0 {
|
||||||
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
|
||||||
|
}
|
||||||
|
|
||||||
if len(v6UnsafeNetworks) > 0 {
|
if len(v6UnsafeNetworks) > 0 {
|
||||||
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
|
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &cert.TBSCertificate{
|
t := &cert.TBSCertificate{
|
||||||
@@ -325,8 +321,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
crts = append(crts, nc)
|
crts = append(crts, nc)
|
||||||
|
}
|
||||||
|
|
||||||
case cert.Version2:
|
if version == 0 || version == cert.Version2 {
|
||||||
t := &cert.TBSCertificate{
|
t := &cert.TBSCertificate{
|
||||||
Version: cert.Version2,
|
Version: cert.Version2,
|
||||||
Name: *sf.name,
|
Name: *sf.name,
|
||||||
@@ -354,9 +351,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
crts = append(crts, nc)
|
crts = append(crts, nc)
|
||||||
default:
|
|
||||||
// this should be unreachable
|
|
||||||
return fmt.Errorf("invalid version: %d", version)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isP11 && *sf.inPubPath == "" {
|
if !isP11 && *sf.inPubPath == "" {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
|
|||||||
" -unsafe-networks string\n"+
|
" -unsafe-networks string\n"+
|
||||||
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
||||||
" -version uint\n"+
|
" -version uint\n"+
|
||||||
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
|
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
||||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -379,15 +379,6 @@ func Test_signCert(t *testing.T) {
|
|||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// test with the proper password in the environment
|
|
||||||
os.Remove(crtF.Name())
|
|
||||||
os.Remove(keyF.Name())
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
|
||||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
|
||||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
|
||||||
assert.Empty(t, eb.String())
|
|
||||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
|
||||||
|
|
||||||
// test with the wrong password
|
// test with the wrong password
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
@@ -398,17 +389,6 @@ func Test_signCert(t *testing.T) {
|
|||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// test with the wrong password in environment
|
|
||||||
ob.Reset()
|
|
||||||
eb.Reset()
|
|
||||||
|
|
||||||
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
|
|
||||||
assert.Empty(t, ob.String())
|
|
||||||
assert.Empty(t, eb.String())
|
|
||||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
|
||||||
|
|
||||||
// test with the user not entering a password
|
// test with the user not entering a password
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
@@ -20,17 +18,6 @@ import (
|
|||||||
// at compile-time.
|
// at compile-time.
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
func init() {
|
|
||||||
if Build == "" {
|
|
||||||
info, ok := debug.ReadBuildInfo()
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
serviceFlag := flag.String("service", "", "Control the system service.")
|
serviceFlag := flag.String("service", "", "Control the system service.")
|
||||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
@@ -20,17 +18,6 @@ import (
|
|||||||
// at compile-time.
|
// at compile-time.
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
func init() {
|
|
||||||
if Build == "" {
|
|
||||||
info, ok := debug.ReadBuildInfo()
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||||
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
|
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type C struct {
|
type C struct {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
|
|||||||
@@ -50,6 +50,11 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
|||||||
}
|
}
|
||||||
|
|
||||||
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
||||||
|
|
||||||
|
b := NewBits(ReplayWindow)
|
||||||
|
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
|
||||||
|
b.Update(l, 0)
|
||||||
|
|
||||||
hs, err := noise.NewHandshakeState(noise.Config{
|
hs, err := noise.NewHandshakeState(noise.Config{
|
||||||
CipherSuite: ncs,
|
CipherSuite: ncs,
|
||||||
Random: rand.Reader,
|
Random: rand.Reader,
|
||||||
@@ -69,7 +74,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
|||||||
ci := &ConnectionState{
|
ci := &ConnectionState{
|
||||||
H: hs,
|
H: hs,
|
||||||
initiator: initiator,
|
initiator: initiator,
|
||||||
window: NewBits(ReplayWindow),
|
window: b,
|
||||||
myCert: crt,
|
myCert: crt,
|
||||||
}
|
}
|
||||||
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
||||||
|
|||||||
@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
|
|||||||
return c.f.hostMap
|
return c.f.hostMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetF() *Interface {
|
|
||||||
return c.f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) GetCertState() *CertState {
|
func (c *Control) GetCertState() *CertState {
|
||||||
return c.f.pki.getCertState()
|
return c.f.pki.getCertState()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,17 +20,16 @@ import (
|
|||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkHotPath(b *testing.B) {
|
func BenchmarkHotPath(b *testing.B) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
// Put their info in our lighthouse
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
myControl.Start()
|
myControl.Start()
|
||||||
@@ -39,9 +38,6 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
r := router.NewR(b, myControl, theirControl)
|
r := router.NewR(b, myControl, theirControl)
|
||||||
r.CancelFlowLogs()
|
r.CancelFlowLogs()
|
||||||
|
|
||||||
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
_ = r.RouteForAllUntilTxTun(theirControl)
|
||||||
@@ -51,39 +47,6 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkHotPathRelay(b *testing.B) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(b, myControl, relayControl, theirControl)
|
|
||||||
r.CancelFlowLogs()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
relayControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
|
||||||
}
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
relayControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGoodHandshake(t *testing.T) {
|
func TestGoodHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
@@ -134,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGoodHandshakeNoOverlap(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
|
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
empty := []byte{}
|
|
||||||
t.Log("do something to cause a handshake")
|
|
||||||
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
|
|
||||||
|
|
||||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
|
||||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
|
||||||
|
|
||||||
t.Log("Get their stage 1 packet")
|
|
||||||
stage1Packet := theirControl.GetFromUDP(true)
|
|
||||||
|
|
||||||
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
|
|
||||||
myControl.InjectUDPPacket(stage1Packet)
|
|
||||||
|
|
||||||
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
|
|
||||||
myControl.WaitForType(header.Test, 0, theirControl)
|
|
||||||
|
|
||||||
t.Log("Make sure our host infos are correct")
|
|
||||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWrongResponderHandshake(t *testing.T) {
|
func TestWrongResponderHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
@@ -536,35 +464,6 @@ func TestRelays(t *testing.T) {
|
|||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRelaysDontCareAboutIps(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
relayControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
|
||||||
r.Log("Assert the tunnel works")
|
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReestablishRelays(t *testing.T) {
|
func TestReestablishRelays(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
@@ -1328,109 +1227,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
|
||||||
|
|
||||||
o := m{
|
|
||||||
"static_host_map": m{
|
|
||||||
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
|
|
||||||
},
|
|
||||||
"lighthouse": m{
|
|
||||||
"hosts": []string{lhVpnIpNet[0].Addr().String()},
|
|
||||||
"local_allow_list": m{
|
|
||||||
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
|
|
||||||
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
|
|
||||||
"10.0.0.0/24": true,
|
|
||||||
"::/0": false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
|
|
||||||
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(t, lhControl, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
lhControl.Start()
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Stand up an ipv6 tunnel between me and them")
|
|
||||||
assert.True(t, myVpnIpNet[1].Addr().Is6())
|
|
||||||
assert.True(t, theirVpnIpNet[1].Addr().Is6())
|
|
||||||
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
|
|
||||||
|
|
||||||
lhControl.Stop()
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
|
||||||
unsafePrefix := "192.168.6.0/24"
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
|
|
||||||
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
|
|
||||||
myCfg := m{
|
|
||||||
"tun": m{
|
|
||||||
"unsafe_routes": []m{route},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
|
|
||||||
t.Logf("my config %v", myConfig)
|
|
||||||
// Put their info in our lighthouse
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
spookyDest := netip.MustParseAddr("192.168.6.4")
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
|
||||||
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
|
||||||
|
|
||||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
|
||||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
|
||||||
|
|
||||||
t.Log("Get their stage 1 packet so that we can play with it")
|
|
||||||
stage1Packet := theirControl.GetFromUDP(true)
|
|
||||||
|
|
||||||
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
|
|
||||||
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
|
|
||||||
badPacket := stage1Packet.Copy()
|
|
||||||
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
|
|
||||||
myControl.InjectUDPPacket(badPacket)
|
|
||||||
|
|
||||||
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
|
|
||||||
myControl.InjectUDPPacket(stage1Packet)
|
|
||||||
|
|
||||||
t.Log("Wait until we see my cached packet come through")
|
|
||||||
myControl.WaitForType(1, 0, theirControl)
|
|
||||||
|
|
||||||
t.Log("Make sure our host infos are correct")
|
|
||||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
|
||||||
|
|
||||||
t.Log("Get that cached packet and make sure it looks right")
|
|
||||||
myCachedPacket := theirControl.GetFromTun(true)
|
|
||||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
|
|
||||||
|
|
||||||
//reply
|
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
|
|
||||||
//wait for reply
|
|
||||||
theirControl.WaitForType(1, 0, myControl)
|
|
||||||
theirCachedPacket := myControl.GetFromTun(true)
|
|
||||||
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
|
|
||||||
|
|
||||||
t.Log("Do a bidirectional tunnel test")
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -22,14 +22,15 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"gopkg.in/yaml.v3"
|
||||||
"go.yaml.in/yaml/v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
}
|
|
||||||
|
|
||||||
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnNetworks) == 0 {
|
|
||||||
panic("no vpn networks")
|
|
||||||
}
|
|
||||||
|
|
||||||
firewallInbound := []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}}
|
|
||||||
|
|
||||||
var unsafeNetworks []netip.Prefix
|
|
||||||
if sUnsafeNetworks != "" {
|
|
||||||
firewallInbound = []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
"local_cidr": "0.0.0.0/0",
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
|
|
||||||
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
unsafeNetworks = append(unsafeNetworks, x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
|
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
|||||||
"port": "any",
|
"port": "any",
|
||||||
"host": "any",
|
"host": "any",
|
||||||
}},
|
}},
|
||||||
"inbound": firewallInbound,
|
"inbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
},
|
},
|
||||||
//"handshakes": m{
|
//"handshakes": m{
|
||||||
// "try_interval": "1s",
|
// "try_interval": "1s",
|
||||||
@@ -292,7 +250,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
||||||
// Send a packet from them to me
|
// Send a packet from them to me
|
||||||
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
||||||
bPacket := r.RouteForAllUntilTxTun(controlA)
|
bPacket := r.RouteForAllUntilTxTun(controlA)
|
||||||
@@ -304,14 +262,14 @@ func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *n
|
|||||||
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
|
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
|
||||||
// Get both host infos
|
// Get both host infos
|
||||||
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
||||||
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
||||||
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||||
|
|
||||||
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
||||||
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||||
|
|
||||||
// Check that both vpn and real addr are correct
|
// Check that both vpn and real addr are correct
|
||||||
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
||||||
@@ -325,7 +283,7 @@ func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpn
|
|||||||
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
|
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
if toIp.Is6() {
|
if toIp.Is6() {
|
||||||
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
||||||
} else {
|
} else {
|
||||||
@@ -333,7 +291,7 @@ func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
||||||
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
||||||
assert.NotNil(t, v6, "No ipv6 data found")
|
assert.NotNil(t, v6, "No ipv6 data found")
|
||||||
@@ -352,7 +310,7 @@ func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
|
|||||||
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||||
assert.NotNil(t, v4, "No ipv4 data found")
|
assert.NotNil(t, v4, "No ipv4 data found")
|
||||||
|
|||||||
@@ -318,50 +318,3 @@ func TestCertMismatchCorrection(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCrossStackRelaysWork(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
|
||||||
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
|
||||||
|
|
||||||
//myVpnV4 := myVpnIpNet[0]
|
|
||||||
myVpnV6 := myVpnIpNet[1]
|
|
||||||
relayVpnV4 := relayVpnIpNet[0]
|
|
||||||
relayVpnV6 := relayVpnIpNet[1]
|
|
||||||
theirVpnV6 := theirVpnIpNet[0]
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
relayControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
|
||||||
r.Log("Assert the tunnel works")
|
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
|
||||||
|
|
||||||
t.Log("reply?")
|
|
||||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
|
||||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
|
||||||
//t.Log("finish up")
|
|
||||||
//myControl.Stop()
|
|
||||||
//theirControl.Stop()
|
|
||||||
//relayControl.Stop()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -383,9 +383,8 @@ firewall:
|
|||||||
# host: `any` or a literal hostname, ie `test-host`
|
# host: `any` or a literal hostname, ie `test-host`
|
||||||
# group: `any` or a literal group name, ie `default-group`
|
# group: `any` or a literal group name, ie `default-group`
|
||||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||||
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
|
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
|
||||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
|
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
|
||||||
# This can be used to filter destinations when using unsafe_routes.
|
|
||||||
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
|
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
|
||||||
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
|
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
|
||||||
# ca_name: An issuing CA name
|
# ca_name: An issuing CA name
|
||||||
|
|||||||
208
firewall.go
208
firewall.go
@@ -8,7 +8,6 @@ import (
|
|||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -23,7 +22,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type FirewallInterface interface {
|
type FirewallInterface interface {
|
||||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
@@ -248,11 +247,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
||||||
|
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
||||||
|
// https://github.com/golang/go/issues/14131
|
||||||
|
sIp := ""
|
||||||
|
if ip.IsValid() {
|
||||||
|
sIp = ip.String()
|
||||||
|
}
|
||||||
|
lIp := ""
|
||||||
|
if localIp.IsValid() {
|
||||||
|
lIp = localIp.String()
|
||||||
|
}
|
||||||
|
|
||||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||||
ruleString := fmt.Sprintf(
|
ruleString := fmt.Sprintf(
|
||||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||||
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
|
||||||
)
|
)
|
||||||
f.rules += ruleString + "\n"
|
f.rules += ruleString + "\n"
|
||||||
|
|
||||||
@@ -260,7 +270,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
if !incoming {
|
if !incoming {
|
||||||
direction = "outgoing"
|
direction = "outgoing"
|
||||||
}
|
}
|
||||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
|
||||||
Info("Firewall rule added")
|
Info("Firewall rule added")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -287,7 +297,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
return fmt.Errorf("unknown protocol %v", proto)
|
return fmt.Errorf("unknown protocol %v", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
||||||
@@ -327,6 +337,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, t := range rs {
|
for i, t := range rs {
|
||||||
|
var groups []string
|
||||||
r, err := convertRule(l, t, table, i)
|
r, err := convertRule(l, t, table, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
||||||
@@ -336,10 +347,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
|
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
|
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
|
||||||
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
|
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(r.Groups) > 0 {
|
||||||
|
groups = r.Groups
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Group != "" {
|
||||||
|
// Check if we have both groups and group provided in the rule config
|
||||||
|
if len(groups) > 0 {
|
||||||
|
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
groups = []string{r.Group}
|
||||||
|
}
|
||||||
|
|
||||||
var sPort, errPort string
|
var sPort, errPort string
|
||||||
if r.Code != "" {
|
if r.Code != "" {
|
||||||
errPort = "code"
|
errPort = "code"
|
||||||
@@ -368,25 +392,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Cidr != "" && r.Cidr != "any" {
|
var cidr netip.Prefix
|
||||||
_, err = netip.ParsePrefix(r.Cidr)
|
if r.Cidr != "" {
|
||||||
|
cidr, err = netip.ParsePrefix(r.Cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.LocalCidr != "" && r.LocalCidr != "any" {
|
var localCidr netip.Prefix
|
||||||
_, err = netip.ParsePrefix(r.LocalCidr)
|
if r.LocalCidr != "" {
|
||||||
|
localCidr, err = netip.ParsePrefix(r.LocalCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if warning := r.sanity(); warning != nil {
|
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
|
||||||
l.Warnf("%s rule #%v; %s", table, i, warning)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
||||||
}
|
}
|
||||||
@@ -395,45 +417,30 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrUnknownNetworkType = errors.New("unknown network type")
|
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||||
var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
|
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||||
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
|
|
||||||
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
|
||||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
|
|
||||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
// returns nil if the packet should not be dropped.
|
// returns nil if the packet should not be dropped.
|
||||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error {
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// Check if we spoke to this tuple, if we did then allow this packet
|
||||||
if f.inConns(fp, h, caPool, localCache) {
|
if f.inConns(fp, h, caPool, localCache) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
// Make sure remote address matches nebula certificate
|
||||||
if h.networks == nil {
|
if h.networks != nil {
|
||||||
// Simple case: Certificate has one address and no unsafe networks
|
if !h.networks.Contains(fp.RemoteAddr) {
|
||||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
// Simple case: Certificate has one address and no unsafe networks
|
||||||
if !ok {
|
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
switch nwType {
|
|
||||||
case NetworkTypeVPN:
|
|
||||||
break // nothing special
|
|
||||||
case NetworkTypeVPNPeer:
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
|
||||||
case NetworkTypeUnsafe:
|
|
||||||
break // nothing special, one day this may have different FW rules
|
|
||||||
default:
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrUnknownNetworkType //should never happen
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
@@ -483,11 +490,9 @@ func (f *Firewall) EmitStats() {
|
|||||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool {
|
||||||
if localCache != nil {
|
if localCache != nil && localCache.Has(fp) {
|
||||||
if _, ok := localCache[fp]; ok {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
conntrack.Lock()
|
conntrack.Lock()
|
||||||
@@ -552,7 +557,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
|
|
||||||
if localCache != nil {
|
if localCache != nil {
|
||||||
localCache[fp] = struct{}{}
|
localCache.Add(fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -633,7 +638,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
||||||
if startPort > endPort {
|
if startPort > endPort {
|
||||||
return fmt.Errorf("start port was lower than end port")
|
return fmt.Errorf("start port was lower than end port")
|
||||||
}
|
}
|
||||||
@@ -646,7 +651,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil {
|
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -677,7 +682,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
|
|||||||
return fp[firewall.PortAny].match(p, c, caPool)
|
return fp[firewall.PortAny].match(p, c, caPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error {
|
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
|
||||||
fr := func() *FirewallRule {
|
fr := func() *FirewallRule {
|
||||||
return &FirewallRule{
|
return &FirewallRule{
|
||||||
Hosts: make(map[string]*firewallLocalCIDR),
|
Hosts: make(map[string]*firewallLocalCIDR),
|
||||||
@@ -691,14 +696,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
|
|||||||
fc.Any = fr()
|
fc.Any = fr()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fc.Any.addRule(f, groups, host, cidr, localCidr)
|
return fc.Any.addRule(f, groups, host, ip, localIp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if caSha != "" {
|
if caSha != "" {
|
||||||
if _, ok := fc.CAShas[caSha]; !ok {
|
if _, ok := fc.CAShas[caSha]; !ok {
|
||||||
fc.CAShas[caSha] = fr()
|
fc.CAShas[caSha] = fr()
|
||||||
}
|
}
|
||||||
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr)
|
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -708,7 +713,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
|
|||||||
if _, ok := fc.CANames[caName]; !ok {
|
if _, ok := fc.CANames[caName]; !ok {
|
||||||
fc.CANames[caName] = fr()
|
fc.CANames[caName] = fr()
|
||||||
}
|
}
|
||||||
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr)
|
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -740,24 +745,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
|
|||||||
return fc.CANames[s.Certificate.Name()].match(p, c)
|
return fc.CANames[s.Certificate.Name()].match(p, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error {
|
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
||||||
flc := func() *firewallLocalCIDR {
|
flc := func() *firewallLocalCIDR {
|
||||||
return &firewallLocalCIDR{
|
return &firewallLocalCIDR{
|
||||||
LocalCIDR: new(bart.Lite),
|
LocalCIDR: new(bart.Lite),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fr.isAny(groups, host, cidr) {
|
if fr.isAny(groups, host, ip) {
|
||||||
if fr.Any == nil {
|
if fr.Any == nil {
|
||||||
fr.Any = flc()
|
fr.Any = flc()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fr.Any.addRule(f, localCidr)
|
return fr.Any.addRule(f, localCIDR)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(groups) > 0 {
|
if len(groups) > 0 {
|
||||||
nlc := flc()
|
nlc := flc()
|
||||||
err := nlc.addRule(f, localCidr)
|
err := nlc.addRule(f, localCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -773,34 +778,30 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localC
|
|||||||
if nlc == nil {
|
if nlc == nil {
|
||||||
nlc = flc()
|
nlc = flc()
|
||||||
}
|
}
|
||||||
err := nlc.addRule(f, localCidr)
|
err := nlc.addRule(f, localCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fr.Hosts[host] = nlc
|
fr.Hosts[host] = nlc
|
||||||
}
|
}
|
||||||
|
|
||||||
if cidr != "" {
|
if ip.IsValid() {
|
||||||
c, err := netip.ParsePrefix(cidr)
|
nlc, _ := fr.CIDR.Get(ip)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nlc, _ := fr.CIDR.Get(c)
|
|
||||||
if nlc == nil {
|
if nlc == nil {
|
||||||
nlc = flc()
|
nlc = flc()
|
||||||
}
|
}
|
||||||
err = nlc.addRule(f, localCidr)
|
err := nlc.addRule(f, localCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fr.CIDR.Insert(c, nlc)
|
fr.CIDR.Insert(ip, nlc)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
|
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
|
||||||
if len(groups) == 0 && host == "" && cidr == "" {
|
if len(groups) == 0 && host == "" && !ip.IsValid() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -814,7 +815,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if cidr == "any" {
|
if ip.IsValid() && ip.Bits() == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -866,13 +867,8 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||||
if localCidr == "any" {
|
if !localIp.IsValid() {
|
||||||
flc.Any = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if localCidr == "" {
|
|
||||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
||||||
flc.Any = true
|
flc.Any = true
|
||||||
return nil
|
return nil
|
||||||
@@ -883,13 +879,12 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
} else if localIp.Bits() == 0 {
|
||||||
|
flc.Any = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := netip.ParsePrefix(localCidr)
|
flc.LocalCIDR.Insert(localIp)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
flc.LocalCIDR.Insert(c)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -910,6 +905,7 @@ type rule struct {
|
|||||||
Code string
|
Code string
|
||||||
Proto string
|
Proto string
|
||||||
Host string
|
Host string
|
||||||
|
Group string
|
||||||
Groups []string
|
Groups []string
|
||||||
Cidr string
|
Cidr string
|
||||||
LocalCidr string
|
LocalCidr string
|
||||||
@@ -951,8 +947,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
||||||
m["group"] = v[0]
|
m["group"] = v[0]
|
||||||
}
|
}
|
||||||
|
r.Group = toString("group", m)
|
||||||
singleGroup := toString("group", m)
|
|
||||||
|
|
||||||
if rg, ok := m["groups"]; ok {
|
if rg, ok := m["groups"]; ok {
|
||||||
switch reflect.TypeOf(rg).Kind() {
|
switch reflect.TypeOf(rg).Kind() {
|
||||||
@@ -969,60 +964,9 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//flatten group vs groups
|
|
||||||
if singleGroup != "" {
|
|
||||||
// Check if we have both groups and group provided in the rule config
|
|
||||||
if len(r.Groups) > 0 {
|
|
||||||
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
|
|
||||||
}
|
|
||||||
r.Groups = []string{singleGroup}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
|
|
||||||
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
|
|
||||||
func (r *rule) sanity() error {
|
|
||||||
//port, proto, local_cidr are AND, no need to check here
|
|
||||||
//ca_sha and ca_name don't have a wildcard value, no need to check here
|
|
||||||
groupsEmpty := len(r.Groups) == 0
|
|
||||||
hostEmpty := r.Host == ""
|
|
||||||
cidrEmpty := r.Cidr == ""
|
|
||||||
|
|
||||||
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
|
|
||||||
return nil //no content!
|
|
||||||
}
|
|
||||||
|
|
||||||
groupsHasAny := slices.Contains(r.Groups, "any")
|
|
||||||
if groupsHasAny && len(r.Groups) > 1 {
|
|
||||||
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Host == "any" {
|
|
||||||
if !groupsEmpty {
|
|
||||||
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cidrEmpty {
|
|
||||||
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if groupsHasAny {
|
|
||||||
if !hostEmpty && r.Host != "any" {
|
|
||||||
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
|
|
||||||
}
|
|
||||||
if !cidrEmpty {
|
|
||||||
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//todo alert on cidr-any
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parsePort(s string) (startPort, endPort int32, err error) {
|
func parsePort(s string) (startPort, endPort int32, err error) {
|
||||||
if s == "any" {
|
if s == "any" {
|
||||||
startPort = firewall.PortAny
|
startPort = firewall.PortAny
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -9,13 +10,58 @@ import (
|
|||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
// has been seen in the conntrack table.
|
// has been seen in the conntrack table.
|
||||||
type ConntrackCache map[Packet]struct{}
|
type ConntrackCache struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
entries map[Packet]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConntrackCache() *ConntrackCache {
|
||||||
|
return &ConntrackCache{entries: make(map[Packet]struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Has(p Packet) bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
_, ok := c.entries[p]
|
||||||
|
c.mu.Unlock()
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Add(p Packet) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.entries[p] = struct{}{}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Len() int {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
l := len(c.entries)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Reset(capHint int) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.entries = make(map[Packet]struct{}, capHint)
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
type ConntrackCacheTicker struct {
|
type ConntrackCacheTicker struct {
|
||||||
cacheV uint64
|
cacheV uint64
|
||||||
cacheTick atomic.Uint64
|
cacheTick atomic.Uint64
|
||||||
|
|
||||||
cache ConntrackCache
|
cache *ConntrackCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||||
@@ -23,9 +69,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ConntrackCacheTicker{
|
c := &ConntrackCacheTicker{cache: newConntrackCache()}
|
||||||
cache: ConntrackCache{},
|
|
||||||
}
|
|
||||||
|
|
||||||
go c.tick(d)
|
go c.tick(d)
|
||||||
|
|
||||||
@@ -41,17 +85,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
|||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
// the map. If it has moved, we reset the map.
|
// the map. If it has moved, we reset the map.
|
||||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||||
c.cacheV = tick
|
c.cacheV = tick
|
||||||
if ll := len(c.cache); ll > 0 {
|
if ll := c.cache.Len(); ll > 0 {
|
||||||
if l.Level == logrus.DebugLevel {
|
if l.Level == logrus.DebugLevel {
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||||
}
|
}
|
||||||
c.cache = make(ConntrackCache, ll)
|
c.cache.Reset(ll)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
445
firewall_test.go
445
firewall_test.go
@@ -8,8 +8,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -73,114 +71,85 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
ti6, err := netip.ParsePrefix("fd12::34/128")
|
ti6, err := netip.ParsePrefix("fd12::34/128")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
// An empty rule is any
|
// An empty rule is any
|
||||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
||||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
|
|
||||||
assert.True(t, table.Any)
|
|
||||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
anyIp6, err := netip.ParsePrefix("::/0")
|
anyIp6, err := netip.ParsePrefix("::/0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
|
|
||||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
|
|
||||||
assert.True(t, table.Any)
|
|
||||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
|
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
|
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -205,10 +174,10 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, &c)
|
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
@@ -227,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -284,10 +250,10 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, &c)
|
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
@@ -306,28 +272,28 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,12 +304,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
||||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
|
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
||||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
|
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
||||||
|
|
||||||
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
||||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
|
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
|
||||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
|
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
b.Run("fail on proto", func(b *testing.B) {
|
b.Run("fail on proto", func(b *testing.B) {
|
||||||
@@ -487,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -514,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c1 := cert.CachedCertificate{
|
c1 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -529,10 +493,10 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// h1/c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
@@ -546,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -579,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c2 := cert.CachedCertificate{
|
c2 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -594,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
|
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c3 := cert.CachedCertificate{
|
c3 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -609,11 +571,11 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// c1 should pass because host match
|
// c1 should pass because host match
|
||||||
@@ -627,7 +589,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -635,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -660,12 +620,12 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -673,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -701,10 +659,10 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
@@ -717,7 +675,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
|
|
||||||
oldFw := fw
|
oldFw := fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
@@ -726,7 +684,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
|
|
||||||
oldFw = fw
|
oldFw = fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
@@ -738,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
|
||||||
|
|
||||||
c := cert.CachedCertificate{
|
c := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -761,11 +717,11 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
||||||
@@ -959,28 +915,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with cidr
|
// Test adding rule with cidr
|
||||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||||
@@ -988,14 +944,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with local_cidr
|
// Test adding rule with local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with cidr ipv6
|
// Test adding rule with cidr ipv6
|
||||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||||
@@ -1003,75 +959,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
|
||||||
|
|
||||||
// Test adding rule with junk cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
|
||||||
|
|
||||||
// Test adding rule with local_cidr ipv6
|
// Test adding rule with local_cidr ipv6
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any local_cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
|
||||||
|
|
||||||
// Test adding rule with junk local_cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
@@ -1094,7 +1024,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
r, err := convertRule(l, c, "test", 1)
|
r, err := convertRule(l, c, "test", 1)
|
||||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
assert.Equal(t, "group1", r.Group)
|
||||||
|
|
||||||
// Ensure group array of > 1 is errord
|
// Ensure group array of > 1 is errord
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
@@ -1114,228 +1044,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
|
|
||||||
r, err = convertRule(l, c, "test", 1)
|
r, err = convertRule(l, c, "test", 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
assert.Equal(t, "group1", r.Group)
|
||||||
}
|
|
||||||
|
|
||||||
func TestFirewall_convertRuleSanity(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
noWarningPlease := []map[string]any{
|
|
||||||
{"group": "group1"},
|
|
||||||
{"groups": []any{"group2"}},
|
|
||||||
{"host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
}
|
|
||||||
for _, c := range noWarningPlease {
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
|
|
||||||
}
|
|
||||||
|
|
||||||
yesWarningPlease := []map[string]any{
|
|
||||||
{"group": "group1"},
|
|
||||||
{"groups": []any{"group2"}},
|
|
||||||
{"cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
}
|
|
||||||
for _, c := range yesWarningPlease {
|
|
||||||
c["host"] = "any"
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = r.sanity()
|
|
||||||
require.Error(t, err, "I wanted a warning: %+v", c)
|
|
||||||
}
|
|
||||||
//reset the list
|
|
||||||
yesWarningPlease = []map[string]any{
|
|
||||||
{"group": "group1"},
|
|
||||||
{"groups": []any{"group2"}},
|
|
||||||
{"cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
}
|
|
||||||
for _, c := range yesWarningPlease {
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
r.Groups = append(r.Groups, "any")
|
|
||||||
err = r.sanity()
|
|
||||||
require.Error(t, err, "I wanted a warning: %+v", c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testcase struct {
|
|
||||||
h *HostInfo
|
|
||||||
p firewall.Packet
|
|
||||||
c cert.Certificate
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|
||||||
t.Helper()
|
|
||||||
cp := cert.NewCAPool()
|
|
||||||
resetConntrack(fw)
|
|
||||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
|
||||||
if c.err == nil {
|
|
||||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
|
||||||
} else {
|
|
||||||
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
|
||||||
c1 := dummyCert{
|
|
||||||
name: "host1",
|
|
||||||
networks: theirPrefixes,
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
h := HostInfo{
|
|
||||||
ConnectionState: &ConnectionState{
|
|
||||||
peerCert: &cert.CachedCertificate{
|
|
||||||
Certificate: &c1,
|
|
||||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
|
|
||||||
}
|
|
||||||
for i := range theirPrefixes {
|
|
||||||
h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
|
||||||
}
|
|
||||||
h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
|
||||||
p := firewall.Packet{
|
|
||||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
|
||||||
RemoteAddr: theirPrefixes[0].Addr(),
|
|
||||||
LocalPort: 10,
|
|
||||||
RemotePort: 90,
|
|
||||||
Protocol: firewall.ProtoUDP,
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
return testcase{
|
|
||||||
h: &h,
|
|
||||||
p: p,
|
|
||||||
c: &c1,
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testsetup struct {
|
|
||||||
c dummyCert
|
|
||||||
myVpnNetworksTable *bart.Lite
|
|
||||||
fw *Firewall
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
|
||||||
c := dummyCert{
|
|
||||||
name: "me",
|
|
||||||
networks: myPrefixes,
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
|
|
||||||
return newSetupFromCert(t, l, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
for _, prefix := range c.Networks() {
|
|
||||||
myVpnNetworksTable.Insert(prefix)
|
|
||||||
}
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
|
||||||
|
|
||||||
return testsetup{
|
|
||||||
c: c,
|
|
||||||
fw: fw,
|
|
||||||
myVpnNetworksTable: myVpnNetworksTable,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
|
||||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
|
||||||
t.Run("allow inbound all matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound local matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("block inbound remote mismatched", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("Block a vpn peer packet", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
twoPrefixes := []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
|
|
||||||
}
|
|
||||||
t.Run("allow inbound one matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, nil, twoPrefixes...)
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("block inbound multimismatch", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
|
|
||||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
|
|
||||||
tc := buildTestCase(setup2, nil, twoPrefixes...)
|
|
||||||
tc.p.RemoteAddr = twoPrefixes[1].Addr()
|
|
||||||
tc.Test(t, setup2.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound unsafe route", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
|
||||||
c := dummyCert{
|
|
||||||
name: "me",
|
|
||||||
networks: []netip.Prefix{myPrefix},
|
|
||||||
unsafeNetworks: []netip.Prefix{unsafePrefix},
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
unsafeSetup := newSetupFromCert(t, l, c)
|
|
||||||
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
|
||||||
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
|
||||||
tc.err = ErrNoMatchingRule
|
|
||||||
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
|
||||||
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
|
|
||||||
tc.err = nil
|
|
||||||
tc.Test(t, unsafeSetup.fw) //should pass
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type addRuleCall struct {
|
type addRuleCall struct {
|
||||||
@@ -1345,8 +1054,8 @@ type addRuleCall struct {
|
|||||||
endPort int32
|
endPort int32
|
||||||
groups []string
|
groups []string
|
||||||
host string
|
host string
|
||||||
ip string
|
ip netip.Prefix
|
||||||
localIp string
|
localIp netip.Prefix
|
||||||
caName string
|
caName string
|
||||||
caSha string
|
caSha string
|
||||||
}
|
}
|
||||||
@@ -1356,7 +1065,7 @@ type mockFirewall struct {
|
|||||||
nextCallReturn error
|
nextCallReturn error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
|
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
|
||||||
mf.lastCall = addRuleCall{
|
mf.lastCall = addRuleCall{
|
||||||
incoming: incoming,
|
incoming: incoming,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
|
|||||||
16
go.mod
16
go.mod
@@ -6,9 +6,10 @@ require (
|
|||||||
dario.cat/mergo v1.0.2
|
dario.cat/mergo v1.0.2
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
||||||
github.com/armon/go-radix v1.0.0
|
github.com/armon/go-radix v1.0.0
|
||||||
|
github.com/cilium/ebpf v0.12.3
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||||
github.com/flynn/noise v1.1.0
|
github.com/flynn/noise v1.1.0
|
||||||
github.com/gaissmai/bart v0.26.0
|
github.com/gaissmai/bart v0.25.0
|
||||||
github.com/gogo/protobuf v1.3.2
|
github.com/gogo/protobuf v1.3.2
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/kardianos/service v1.2.4
|
github.com/kardianos/service v1.2.4
|
||||||
@@ -22,17 +23,16 @@ require (
|
|||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
go.yaml.in/yaml/v3 v3.0.4
|
golang.org/x/crypto v0.43.0
|
||||||
golang.org/x/crypto v0.45.0
|
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.47.0
|
golang.org/x/net v0.45.0
|
||||||
golang.org/x/sync v0.18.0
|
golang.org/x/sync v0.17.0
|
||||||
golang.org/x/sys v0.38.0
|
golang.org/x/sys v0.37.0
|
||||||
golang.org/x/term v0.37.0
|
golang.org/x/term v0.36.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/protobuf v1.36.10
|
google.golang.org/protobuf v1.36.8
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||||
)
|
)
|
||||||
|
|||||||
37
go.sum
37
go.sum
@@ -17,6 +17,8 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r
|
|||||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
|
||||||
|
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
@@ -24,8 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||||
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
|
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
|
||||||
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
|
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
|
||||||
|
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||||
@@ -78,8 +82,9 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn
|
|||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
|
||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
@@ -155,15 +160,13 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
|||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
|
||||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
@@ -182,8 +185,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -191,8 +194,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -209,11 +212,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
@@ -244,8 +247,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
|||||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
|||||||
118
handshake_ix.go
118
handshake_ix.go
@@ -2,6 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
@@ -191,17 +192,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var vpnAddrs []netip.Addr
|
||||||
|
var filteredNetworks []netip.Prefix
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
certVersion := remoteCert.Certificate.Version()
|
certVersion := remoteCert.Certificate.Version()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
vpnNetworks := remoteCert.Certificate.Networks()
|
|
||||||
|
|
||||||
anyVpnAddrsInCommon := false
|
for _, network := range remoteCert.Certificate.Networks() {
|
||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
vpnAddr := network.Addr()
|
||||||
for i, network := range vpnNetworks {
|
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -209,10 +210,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrs[i] = network.Addr()
|
|
||||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||||
anyVpnAddrsInCommon = true
|
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filteredNetworks = append(filteredNetworks, network)
|
||||||
|
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnAddrs) == 0 {
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
@@ -249,30 +264,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgRxL := f.l.WithFields(m{
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
"vpnAddrs": vpnAddrs,
|
WithField("certName", certName).
|
||||||
"udpAddr": addr,
|
WithField("certVersion", certVersion).
|
||||||
"certName": certName,
|
WithField("fingerprint", fingerprint).
|
||||||
"certVersion": certVersion,
|
WithField("issuer", issuer).
|
||||||
"fingerprint": fingerprint,
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
"issuer": issuer,
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
Info("Handshake message received")
|
||||||
"responderIndex": hs.Details.ResponderIndex,
|
|
||||||
"remoteIndex": h.RemoteIndex,
|
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
|
||||||
})
|
|
||||||
|
|
||||||
if anyVpnAddrsInCommon {
|
|
||||||
msgRxL.Info("Handshake message received")
|
|
||||||
} else {
|
|
||||||
//todo warn if not lighthouse or relay?
|
|
||||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
WithField("certVersion", ci.myCert.Version()).
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -330,7 +341,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -571,22 +582,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
correctHostResponded := false
|
var vpnAddrs []netip.Addr
|
||||||
anyVpnAddrsInCommon := false
|
var filteredNetworks []netip.Prefix
|
||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
for _, network := range vpnNetworks {
|
||||||
for i, network := range vpnNetworks {
|
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||||
vpnAddrs[i] = network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
anyVpnAddrsInCommon = true
|
continue
|
||||||
}
|
|
||||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
|
||||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
|
||||||
correctHostResponded = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filteredNetworks = append(filteredNetworks, network)
|
||||||
|
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnAddrs) == 0 {
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !correctHostResponded {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
@@ -598,7 +618,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
@@ -625,7 +644,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -633,17 +652,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
WithField("sentCachedPackets", len(hh.packetStore))
|
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||||
if anyVpnAddrsInCommon {
|
Info("Handshake message received")
|
||||||
msgRxL.Info("Handshake message received")
|
|
||||||
} else {
|
|
||||||
//todo warn if not lighthouse or relay?
|
|
||||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
|
|||||||
@@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||||
// Send a RelayRequest to all known Relay IP's
|
// Send a RelayRequest to all known Relay IP's
|
||||||
for _, relay := range hostinfo.remotes.relays {
|
for _, relay := range hostinfo.remotes.relays {
|
||||||
// Don't relay through the host I'm trying to connect to
|
// Don't relay to myself
|
||||||
if relay == vpnIp {
|
if relay == vpnIp {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't relay to myself
|
// Don't relay through the host I'm trying to connect to
|
||||||
if hm.f.myVpnAddrsTable.Contains(relay) {
|
if hm.f.myVpnAddrsTable.Contains(relay) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
40
hostmap.go
40
hostmap.go
@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
|||||||
rs.relayForByIdx[idx] = r
|
rs.relayForByIdx[idx] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetworkType uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
NetworkTypeUnknown NetworkType = iota
|
|
||||||
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
|
||||||
NetworkTypeVPN
|
|
||||||
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
|
||||||
NetworkTypeVPNPeer
|
|
||||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
|
||||||
NetworkTypeUnsafe
|
|
||||||
)
|
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
remote netip.AddrPort
|
remote netip.AddrPort
|
||||||
remotes *RemoteList
|
remotes *RemoteList
|
||||||
@@ -237,8 +225,8 @@ type HostInfo struct {
|
|||||||
// vpn networks but were removed because they are not usable
|
// vpn networks but were removed because they are not usable
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
// networks are both all vpn and unsafe networks assigned to this host
|
||||||
networks *bart.Table[NetworkType]
|
networks *bart.Lite
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@@ -742,26 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
|
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||||
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
|
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||||
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
// Simple case, no CIDRTree needed
|
||||||
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
|
return
|
||||||
return // Simple case, no BART needed
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
i.networks = new(bart.Table[NetworkType])
|
i.networks = new(bart.Lite)
|
||||||
for _, network := range c.Networks() {
|
for _, network := range networks {
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
if myVpnNetworksTable.Contains(network.Addr()) {
|
i.networks.Insert(nprefix)
|
||||||
i.networks.Insert(nprefix, NetworkTypeVPN)
|
|
||||||
} else {
|
|
||||||
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range c.UnsafeNetworks() {
|
for _, network := range unsafeNetworks {
|
||||||
i.networks.Insert(network, NetworkTypeUnsafe)
|
i.networks.Insert(network)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
119
inside.go
119
inside.go
@@ -2,16 +2,18 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -120,10 +122,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
|
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||||
// it does not check if it is within our vpn networks!
|
|
||||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||||
f.handshakeManager.GetOrHandshake(vpnAddr, nil)
|
f.getOrHandshakeNoRouting(vpnAddr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
||||||
@@ -139,6 +140,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
|
|||||||
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
||||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
||||||
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||||
|
|
||||||
destinationAddr := fwPacket.RemoteAddr
|
destinationAddr := fwPacket.RemoteAddr
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
||||||
@@ -231,10 +233,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
|
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||||
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
|
|
||||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||||
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -336,9 +337,21 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
if ci.eKey == nil {
|
if ci.eKey == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
target := remote
|
||||||
|
if !target.IsValid() {
|
||||||
|
target = hostinfo.remote
|
||||||
|
}
|
||||||
|
useRelay := !target.IsValid()
|
||||||
fullOut := out
|
fullOut := out
|
||||||
|
|
||||||
|
var pkt *overlay.Packet
|
||||||
|
if !useRelay && f.batches.Enabled() {
|
||||||
|
pkt = f.batches.newPacket()
|
||||||
|
if pkt != nil {
|
||||||
|
out = pkt.Payload()[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if useRelay {
|
if useRelay {
|
||||||
if len(out) < header.Len {
|
if len(out) < header.Len {
|
||||||
// out always has a capacity of mtu, but not always a length greater than the header.Len.
|
// out always has a capacity of mtu, but not always a length greater than the header.Len.
|
||||||
@@ -372,41 +385,85 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
if len(p) > 0 && slicesOverlap(out, p) {
|
||||||
|
tmp := make([]byte, len(p))
|
||||||
|
copy(tmp, p)
|
||||||
|
p = tmp
|
||||||
|
}
|
||||||
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
|
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
|
||||||
if noiseutil.EncryptLockNeeded {
|
if noiseutil.EncryptLockNeeded {
|
||||||
ci.writeLock.Unlock()
|
ci.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).WithError(err).
|
||||||
WithField("udpAddr", remote).WithField("counter", c).
|
WithField("udpAddr", target).WithField("counter", c).
|
||||||
WithField("attemptedCounter", c).
|
WithField("attemptedCounter", c).
|
||||||
Error("Failed to encrypt outgoing packet")
|
Error("Failed to encrypt outgoing packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if remote.IsValid() {
|
if target.IsValid() {
|
||||||
err = f.writers[q].WriteTo(out, remote)
|
if pkt != nil {
|
||||||
if err != nil {
|
pkt.Len = len(out)
|
||||||
hostinfo.logger(f.l).WithError(err).
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
f.l.WithFields(logrus.Fields{
|
||||||
}
|
"queue": q,
|
||||||
} else if hostinfo.remote.IsValid() {
|
"dest": target,
|
||||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
"payload_len": pkt.Len,
|
||||||
if err != nil {
|
"use_batches": true,
|
||||||
hostinfo.logger(f.l).WithError(err).
|
"remote_index": hostinfo.remoteIndexId,
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
}).Debug("enqueueing packet to UDP batch queue")
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try to send via a relay
|
|
||||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
|
||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.relayState.DeleteRelay(relayIP)
|
|
||||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
if f.tryQueuePacket(q, pkt, target) {
|
||||||
break
|
return
|
||||||
|
}
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithFields(logrus.Fields{
|
||||||
|
"queue": q,
|
||||||
|
"dest": target,
|
||||||
|
}).Debug("failed to enqueue packet; falling back to immediate send")
|
||||||
|
}
|
||||||
|
f.writeImmediatePacket(q, pkt, target, hostinfo)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if f.tryQueueDatagram(q, out, target) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.writeImmediate(q, out, target, hostinfo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// fall back to relay path
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to send via a relay
|
||||||
|
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||||
|
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.relayState.DeleteRelay(relayIP)
|
||||||
|
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// slicesOverlap reports whether the two byte slices share any portion of memory.
|
||||||
|
// cipher.AEAD.Seal requires plaintext and dst to live in disjoint regions.
|
||||||
|
func slicesOverlap(a, b []byte) bool {
|
||||||
|
if len(a) == 0 || len(b) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
aStart := uintptr(unsafe.Pointer(&a[0]))
|
||||||
|
aEnd := aStart + uintptr(len(a))
|
||||||
|
bStart := uintptr(unsafe.Pointer(&b[0]))
|
||||||
|
bEnd := bStart + uintptr(len(b))
|
||||||
|
return aStart < bEnd && bStart < aEnd
|
||||||
|
}
|
||||||
|
|||||||
703
interface.go
703
interface.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,7 +22,13 @@ import (
|
|||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const (
|
||||||
|
mtu = 9001
|
||||||
|
defaultGSOFlushInterval = 150 * time.Microsecond
|
||||||
|
defaultBatchQueueDepthFactor = 4
|
||||||
|
defaultGSOMaxSegments = 8
|
||||||
|
maxKernelGSOSegments = 64
|
||||||
|
)
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
@@ -36,6 +43,9 @@ type InterfaceConfig struct {
|
|||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
DropLocalBroadcast bool
|
DropLocalBroadcast bool
|
||||||
DropMulticast bool
|
DropMulticast bool
|
||||||
|
EnableGSO bool
|
||||||
|
EnableGRO bool
|
||||||
|
GSOMaxSegments int
|
||||||
routines int
|
routines int
|
||||||
MessageMetrics *MessageMetrics
|
MessageMetrics *MessageMetrics
|
||||||
version string
|
version string
|
||||||
@@ -47,6 +57,8 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
|
BatchFlushInterval time.Duration
|
||||||
|
BatchQueueDepth int
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,9 +96,20 @@ type Interface struct {
|
|||||||
version string
|
version string
|
||||||
|
|
||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
|
batchQueueDepth int
|
||||||
|
enableGSO bool
|
||||||
|
enableGRO bool
|
||||||
|
gsoMaxSegments int
|
||||||
|
batchUDPQueueGauge metrics.Gauge
|
||||||
|
batchUDPFlushCounter metrics.Counter
|
||||||
|
batchTunQueueGauge metrics.Gauge
|
||||||
|
batchTunFlushCounter metrics.Counter
|
||||||
|
batchFlushInterval atomic.Int64
|
||||||
|
sendSem chan struct{}
|
||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
|
batches batchPipelines
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
@@ -161,6 +184,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
return nil, errors.New("no connection manager")
|
return nil, errors.New("no connection manager")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.GSOMaxSegments <= 0 {
|
||||||
|
c.GSOMaxSegments = defaultGSOMaxSegments
|
||||||
|
}
|
||||||
|
if c.GSOMaxSegments > maxKernelGSOSegments {
|
||||||
|
c.GSOMaxSegments = maxKernelGSOSegments
|
||||||
|
}
|
||||||
|
if c.BatchQueueDepth <= 0 {
|
||||||
|
c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
|
||||||
|
}
|
||||||
|
if c.BatchFlushInterval < 0 {
|
||||||
|
c.BatchFlushInterval = 0
|
||||||
|
}
|
||||||
|
if c.BatchFlushInterval == 0 && c.EnableGSO {
|
||||||
|
c.BatchFlushInterval = defaultGSOFlushInterval
|
||||||
|
}
|
||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
@@ -186,6 +225,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
relayManager: c.relayManager,
|
relayManager: c.relayManager,
|
||||||
connectionManager: c.connectionManager,
|
connectionManager: c.connectionManager,
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
|
batchQueueDepth: c.BatchQueueDepth,
|
||||||
|
enableGSO: c.EnableGSO,
|
||||||
|
enableGRO: c.EnableGRO,
|
||||||
|
gsoMaxSegments: c.GSOMaxSegments,
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
messageMetrics: c.MessageMetrics,
|
messageMetrics: c.MessageMetrics,
|
||||||
@@ -198,8 +241,25 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
|
ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
|
||||||
|
ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
|
||||||
|
ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
|
||||||
|
ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
|
||||||
|
ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
|
||||||
|
ifce.sendSem = make(chan struct{}, c.routines)
|
||||||
|
ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
|
||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||||
|
if c.l.Level >= logrus.DebugLevel {
|
||||||
|
c.l.WithFields(logrus.Fields{
|
||||||
|
"enableGSO": c.EnableGSO,
|
||||||
|
"enableGRO": c.EnableGRO,
|
||||||
|
"gsoMaxSegments": c.GSOMaxSegments,
|
||||||
|
"batchQueueDepth": c.BatchQueueDepth,
|
||||||
|
"batchFlush": c.BatchFlushInterval,
|
||||||
|
"batching": ifce.batches.Enabled(),
|
||||||
|
}).Debug("initialized batch pipelines")
|
||||||
|
}
|
||||||
|
|
||||||
ifce.connectionManager.intf = ifce
|
ifce.connectionManager.intf = ifce
|
||||||
|
|
||||||
@@ -222,13 +282,6 @@ func (f *Interface) activate() {
|
|||||||
WithField("boringcrypto", boringEnabled()).
|
WithField("boringcrypto", boringEnabled()).
|
||||||
Info("Nebula interface is active")
|
Info("Nebula interface is active")
|
||||||
|
|
||||||
if f.routines > 1 {
|
|
||||||
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
|
||||||
f.routines = 1
|
|
||||||
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
@@ -255,6 +308,18 @@ func (f *Interface) run() {
|
|||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.batches.Enabled() {
|
||||||
|
for i := 0; i < f.routines; i++ {
|
||||||
|
go f.runInsideBatchWorker(i)
|
||||||
|
go f.runTunWriteQueue(i)
|
||||||
|
go f.runSendQueue(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenIn(f.readers[i], i)
|
go f.listenIn(f.readers[i], i)
|
||||||
@@ -286,6 +351,17 @@ func (f *Interface) listenOut(i int) {
|
|||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
|
if f.batches.Enabled() {
|
||||||
|
if br, ok := reader.(overlay.BatchReader); ok {
|
||||||
|
f.listenInBatchLocked(reader, br, i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.listenInLegacyLocked(reader, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
@@ -309,6 +385,581 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
|
||||||
|
pool := f.batches.Pool()
|
||||||
|
if pool == nil {
|
||||||
|
f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
|
||||||
|
f.listenInLegacyLocked(raw, i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
packets, err := reader.ReadIntoBatch(pool)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isVirtioHeadroomError(err) {
|
||||||
|
f.l.WithError(err).Warn("Batch reader fell back due to tun headroom issue")
|
||||||
|
f.listenInLegacyLocked(raw, i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.l.WithError(err).Error("Error while reading outbound packet batch")
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packets) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pkt := range packets {
|
||||||
|
if pkt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !f.batches.enqueueRx(i, pkt) {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) runInsideBatchWorker(i int) {
|
||||||
|
queue := f.batches.rxQueue(i)
|
||||||
|
if queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]byte, mtu)
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
|
for pkt := range queue {
|
||||||
|
if pkt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) runSendQueue(i int) {
|
||||||
|
queue := f.batches.txQueue(i)
|
||||||
|
if queue == nil {
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writer := f.writerForIndex(i)
|
||||||
|
if writer == nil {
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("queue", i).Debug("send queue worker started")
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if f.l.Level >= logrus.WarnLevel {
|
||||||
|
f.l.WithField("queue", i).Warn("send queue worker exited")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
batchCap := f.batches.batchSizeHint()
|
||||||
|
if batchCap <= 0 {
|
||||||
|
batchCap = 1
|
||||||
|
}
|
||||||
|
gsoLimit := f.effectiveGSOMaxSegments()
|
||||||
|
if gsoLimit > batchCap {
|
||||||
|
batchCap = gsoLimit
|
||||||
|
}
|
||||||
|
pending := make([]queuedDatagram, 0, batchCap)
|
||||||
|
var (
|
||||||
|
flushTimer *time.Timer
|
||||||
|
flushC <-chan time.Time
|
||||||
|
)
|
||||||
|
dispatch := func(reason string, timerFired bool) {
|
||||||
|
if len(pending) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
batch := pending
|
||||||
|
f.flushAndReleaseBatch(i, writer, batch, reason)
|
||||||
|
for idx := range batch {
|
||||||
|
batch[idx] = queuedDatagram{}
|
||||||
|
}
|
||||||
|
pending = pending[:0]
|
||||||
|
if flushTimer != nil {
|
||||||
|
if !timerFired {
|
||||||
|
if !flushTimer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-flushTimer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushTimer = nil
|
||||||
|
flushC = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
armTimer := func() {
|
||||||
|
delay := f.currentBatchFlushInterval()
|
||||||
|
if delay <= 0 {
|
||||||
|
dispatch("nogso", false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if flushTimer == nil {
|
||||||
|
flushTimer = time.NewTimer(delay)
|
||||||
|
flushC = flushTimer.C
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case d := <-queue:
|
||||||
|
if d.packet == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithFields(logrus.Fields{
|
||||||
|
"queue": i,
|
||||||
|
"payload_len": d.packet.Len,
|
||||||
|
"dest": d.addr,
|
||||||
|
}).Debug("send queue received packet")
|
||||||
|
}
|
||||||
|
pending = append(pending, d)
|
||||||
|
if gsoLimit > 0 && len(pending) >= gsoLimit {
|
||||||
|
dispatch("gso", false)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(pending) >= cap(pending) {
|
||||||
|
dispatch("cap", false)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
armTimer()
|
||||||
|
f.observeUDPQueueLen(i)
|
||||||
|
case <-flushC:
|
||||||
|
dispatch("timer", true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) runTunWriteQueue(i int) {
|
||||||
|
queue := f.batches.tunQueue(i)
|
||||||
|
if queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writer := f.batches.inside
|
||||||
|
if writer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requiredHeadroom := writer.BatchHeadroom()
|
||||||
|
|
||||||
|
batchCap := f.batches.batchSizeHint()
|
||||||
|
if batchCap <= 0 {
|
||||||
|
batchCap = 1
|
||||||
|
}
|
||||||
|
pending := make([]*overlay.Packet, 0, batchCap)
|
||||||
|
var (
|
||||||
|
flushTimer *time.Timer
|
||||||
|
flushC <-chan time.Time
|
||||||
|
)
|
||||||
|
flush := func(reason string, timerFired bool) {
|
||||||
|
if len(pending) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
valid := pending[:0]
|
||||||
|
for idx := range pending {
|
||||||
|
if !f.ensurePacketHeadroom(&pending[idx], requiredHeadroom, i, reason) {
|
||||||
|
pending[idx] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pending[idx] != nil {
|
||||||
|
valid = append(valid, pending[idx])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(valid) > 0 {
|
||||||
|
if _, err := writer.WriteBatch(valid); err != nil {
|
||||||
|
f.l.WithError(err).
|
||||||
|
WithField("queue", i).
|
||||||
|
WithField("reason", reason).
|
||||||
|
Warn("Failed to write tun batch")
|
||||||
|
for _, pkt := range valid {
|
||||||
|
if pkt != nil {
|
||||||
|
f.writePacketToTun(i, pkt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pending = pending[:0]
|
||||||
|
if flushTimer != nil {
|
||||||
|
if !timerFired {
|
||||||
|
if !flushTimer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-flushTimer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushTimer = nil
|
||||||
|
flushC = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
armTimer := func() {
|
||||||
|
delay := f.currentBatchFlushInterval()
|
||||||
|
if delay <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if flushTimer == nil {
|
||||||
|
flushTimer = time.NewTimer(delay)
|
||||||
|
flushC = flushTimer.C
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case pkt := <-queue:
|
||||||
|
if pkt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.ensurePacketHeadroom(&pkt, requiredHeadroom, i, "queue") {
|
||||||
|
pending = append(pending, pkt)
|
||||||
|
}
|
||||||
|
if len(pending) >= cap(pending) {
|
||||||
|
flush("cap", false)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
armTimer()
|
||||||
|
f.observeTunQueueLen(i)
|
||||||
|
case <-flushC:
|
||||||
|
flush("timer", true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
|
||||||
|
if len(batch) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.flushDatagrams(index, writer, batch, reason)
|
||||||
|
for idx := range batch {
|
||||||
|
if batch[idx].packet != nil {
|
||||||
|
batch[idx].packet.Release()
|
||||||
|
batch[idx].packet = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if f.batchUDPFlushCounter != nil {
|
||||||
|
f.batchUDPFlushCounter.Inc(int64(len(batch)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
|
||||||
|
if len(batch) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithFields(logrus.Fields{
|
||||||
|
"writer": index,
|
||||||
|
"reason": reason,
|
||||||
|
"pending": len(batch),
|
||||||
|
}).Debug("udp batch flush summary")
|
||||||
|
}
|
||||||
|
maxSeg := f.effectiveGSOMaxSegments()
|
||||||
|
if bw, ok := writer.(udp.BatchConn); ok {
|
||||||
|
chunkCap := maxSeg
|
||||||
|
if chunkCap <= 0 {
|
||||||
|
chunkCap = len(batch)
|
||||||
|
}
|
||||||
|
chunk := make([]udp.Datagram, 0, chunkCap)
|
||||||
|
var (
|
||||||
|
currentAddr netip.AddrPort
|
||||||
|
segments int
|
||||||
|
)
|
||||||
|
flushChunk := func() {
|
||||||
|
if len(chunk) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithFields(logrus.Fields{
|
||||||
|
"writer": index,
|
||||||
|
"segments": len(chunk),
|
||||||
|
"dest": chunk[0].Addr,
|
||||||
|
"reason": reason,
|
||||||
|
"pending_total": len(batch),
|
||||||
|
}).Debug("flushing UDP batch")
|
||||||
|
}
|
||||||
|
if err := bw.WriteBatch(chunk); err != nil {
|
||||||
|
f.l.WithError(err).
|
||||||
|
WithField("writer", index).
|
||||||
|
WithField("reason", reason).
|
||||||
|
Warn("Failed to write UDP batch")
|
||||||
|
}
|
||||||
|
chunk = chunk[:0]
|
||||||
|
segments = 0
|
||||||
|
}
|
||||||
|
for _, item := range batch {
|
||||||
|
if item.packet == nil || !item.addr.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := item.packet.Payload()[:item.packet.Len]
|
||||||
|
if segments == 0 {
|
||||||
|
currentAddr = item.addr
|
||||||
|
}
|
||||||
|
if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) {
|
||||||
|
flushChunk()
|
||||||
|
currentAddr = item.addr
|
||||||
|
}
|
||||||
|
chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr})
|
||||||
|
segments++
|
||||||
|
}
|
||||||
|
flushChunk()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, item := range batch {
|
||||||
|
if item.packet == nil || !item.addr.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithFields(logrus.Fields{
|
||||||
|
"writer": index,
|
||||||
|
"reason": reason,
|
||||||
|
"dest": item.addr,
|
||||||
|
"segments": 1,
|
||||||
|
}).Debug("flushing UDP batch")
|
||||||
|
}
|
||||||
|
if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil {
|
||||||
|
f.l.WithError(err).
|
||||||
|
WithField("writer", index).
|
||||||
|
WithField("udpAddr", item.addr).
|
||||||
|
WithField("reason", reason).
|
||||||
|
Warn("Failed to write UDP packet")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool {
|
||||||
|
if !addr.IsValid() || !f.batches.Enabled() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
pkt := f.batches.newPacket()
|
||||||
|
if pkt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
payload := pkt.Payload()
|
||||||
|
if len(payload) < len(buf) {
|
||||||
|
pkt.Release()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
copy(payload, buf)
|
||||||
|
pkt.Len = len(buf)
|
||||||
|
if f.batches.enqueueTx(q, pkt, addr) {
|
||||||
|
f.observeUDPQueueLen(q)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
pkt.Release()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) writerForIndex(i int) udp.Conn {
|
||||||
|
if i < 0 || i >= len(f.writers) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return f.writers[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) {
|
||||||
|
writer := f.writerForIndex(q)
|
||||||
|
if writer == nil {
|
||||||
|
f.l.WithField("udpAddr", addr).
|
||||||
|
WithField("writer", q).
|
||||||
|
Error("Failed to write outgoing packet: no writer available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := writer.WriteTo(buf, addr); err != nil {
|
||||||
|
hostinfo.logger(f.l).
|
||||||
|
WithError(err).
|
||||||
|
WithField("udpAddr", addr).
|
||||||
|
Error("Failed to write outgoing packet")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool {
|
||||||
|
if pkt == nil || !addr.IsValid() || !f.batches.Enabled() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.batches.enqueueTx(q, pkt, addr) {
|
||||||
|
f.observeUDPQueueLen(q)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) {
|
||||||
|
if pkt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writer := f.writerForIndex(q)
|
||||||
|
if writer == nil {
|
||||||
|
f.l.WithField("udpAddr", addr).
|
||||||
|
WithField("writer", q).
|
||||||
|
Error("Failed to write outgoing packet: no writer available")
|
||||||
|
pkt.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil {
|
||||||
|
hostinfo.logger(f.l).
|
||||||
|
WithError(err).
|
||||||
|
WithField("udpAddr", addr).
|
||||||
|
Error("Failed to write outgoing packet")
|
||||||
|
}
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) {
|
||||||
|
if pkt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writer := f.readers[q]
|
||||||
|
if writer == nil {
|
||||||
|
pkt.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if bw, ok := writer.(interface {
|
||||||
|
WriteBatch([]*overlay.Packet) (int, error)
|
||||||
|
}); ok {
|
||||||
|
if _, err := bw.WriteBatch([]*overlay.Packet{pkt}); err != nil {
|
||||||
|
f.l.WithError(err).WithField("queue", q).Warn("Failed to write tun packet via batch writer")
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil {
|
||||||
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
|
}
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) clonePacketWithHeadroom(pkt *overlay.Packet, required int) *overlay.Packet {
|
||||||
|
if pkt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload := pkt.Payload()[:pkt.Len]
|
||||||
|
if len(payload) == 0 && required <= 0 {
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := f.batches.Pool()
|
||||||
|
if pool != nil {
|
||||||
|
if clone := pool.Get(); clone != nil {
|
||||||
|
if len(clone.Payload()) >= len(payload) {
|
||||||
|
clone.Len = copy(clone.Payload(), payload)
|
||||||
|
pkt.Release()
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
clone.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if required < 0 {
|
||||||
|
required = 0
|
||||||
|
}
|
||||||
|
buf := make([]byte, required+len(payload))
|
||||||
|
n := copy(buf[required:], payload)
|
||||||
|
pkt.Release()
|
||||||
|
return &overlay.Packet{
|
||||||
|
Buf: buf,
|
||||||
|
Offset: required,
|
||||||
|
Len: n,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) observeUDPQueueLen(i int) {
|
||||||
|
if f.batchUDPQueueGauge == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) observeTunQueueLen(i int) {
|
||||||
|
if f.batchTunQueueGauge == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) currentBatchFlushInterval() time.Duration {
|
||||||
|
if v := f.batchFlushInterval.Load(); v > 0 {
|
||||||
|
return time.Duration(v)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) ensurePacketHeadroom(pkt **overlay.Packet, required int, queue int, reason string) bool {
|
||||||
|
p := *pkt
|
||||||
|
if p == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if required <= 0 || p.Offset >= required {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
clone := f.clonePacketWithHeadroom(p, required)
|
||||||
|
if clone == nil {
|
||||||
|
f.l.WithFields(logrus.Fields{
|
||||||
|
"queue": queue,
|
||||||
|
"reason": reason,
|
||||||
|
}).Warn("dropping packet lacking tun headroom")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
*pkt = clone
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isVirtioHeadroomError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := err.Error()
|
||||||
|
return strings.Contains(msg, "headroom") || strings.Contains(msg, "virtio")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) effectiveGSOMaxSegments() int {
|
||||||
|
max := f.gsoMaxSegments
|
||||||
|
if max <= 0 {
|
||||||
|
max = defaultGSOMaxSegments
|
||||||
|
}
|
||||||
|
if max > maxKernelGSOSegments {
|
||||||
|
max = maxKernelGSOSegments
|
||||||
|
}
|
||||||
|
if !f.enableGSO {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return max
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpOffloadConfigurator interface {
|
||||||
|
ConfigureOffload(enableGSO, enableGRO bool, maxSegments int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) {
|
||||||
|
if maxSegments <= 0 {
|
||||||
|
maxSegments = defaultGSOMaxSegments
|
||||||
|
}
|
||||||
|
if maxSegments > maxKernelGSOSegments {
|
||||||
|
maxSegments = maxKernelGSOSegments
|
||||||
|
}
|
||||||
|
f.enableGSO = enableGSO
|
||||||
|
f.enableGRO = enableGRO
|
||||||
|
f.gsoMaxSegments = maxSegments
|
||||||
|
for _, writer := range f.writers {
|
||||||
|
if cfg, ok := writer.(udpOffloadConfigurator); ok {
|
||||||
|
cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
c.RegisterReloadCallback(f.reloadFirewall)
|
c.RegisterReloadCallback(f.reloadFirewall)
|
||||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||||
@@ -411,6 +1062,42 @@ func (f *Interface) reloadMisc(c *config.C) {
|
|||||||
f.reQueryWait.Store(int64(n))
|
f.reQueryWait.Store(int64(n))
|
||||||
f.l.Info("timers.requery_wait_duration has changed")
|
f.l.Info("timers.requery_wait_duration has changed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.HasChanged("listen.gso_flush_timeout") {
|
||||||
|
d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
|
||||||
|
if d < 0 {
|
||||||
|
d = 0
|
||||||
|
}
|
||||||
|
f.batchFlushInterval.Store(int64(d))
|
||||||
|
f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed")
|
||||||
|
} else if c.HasChanged("batch.flush_interval") {
|
||||||
|
d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval)
|
||||||
|
if d < 0 {
|
||||||
|
d = 0
|
||||||
|
}
|
||||||
|
f.batchFlushInterval.Store(int64(d))
|
||||||
|
f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.HasChanged("batch.queue_depth") {
|
||||||
|
n := c.GetInt("batch.queue_depth", f.batchQueueDepth)
|
||||||
|
if n != f.batchQueueDepth {
|
||||||
|
f.batchQueueDepth = n
|
||||||
|
f.l.Warn("batch.queue_depth changes require a restart to take effect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") {
|
||||||
|
enableGSO := c.GetBool("listen.enable_gso", f.enableGSO)
|
||||||
|
enableGRO := c.GetBool("listen.enable_gro", f.enableGRO)
|
||||||
|
maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments)
|
||||||
|
f.applyOffloadConfig(enableGSO, enableGRO, maxSeg)
|
||||||
|
f.l.WithFields(logrus.Fields{
|
||||||
|
"enableGSO": enableGSO,
|
||||||
|
"enableGRO": enableGRO,
|
||||||
|
"gsoMaxSegments": maxSeg,
|
||||||
|
}).Info("listen GSO/GRO configuration updated")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||||
|
|||||||
@@ -360,8 +360,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
|
||||||
}
|
}
|
||||||
out[i] = addr
|
out[i] = addr
|
||||||
}
|
}
|
||||||
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
||||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vals, ok := v.([]any)
|
vals, ok := v.([]any)
|
||||||
@@ -1339,19 +1337,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
|
||||||
for _, a := range n.Details.V4AddrPorts {
|
for _, a := range n.Details.V4AddrPorts {
|
||||||
b := protoV4AddrPortToNetAddrPort(a)
|
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
|
||||||
punch(b, detailsVpnAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V6AddrPorts {
|
for _, a := range n.Details.V6AddrPorts {
|
||||||
b := protoV6AddrPortToNetAddrPort(a)
|
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
|
||||||
punch(b, detailsVpnAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOldIPv4Only(t *testing.T) {
|
func TestOldIPv4Only(t *testing.T) {
|
||||||
|
|||||||
64
main.go
64
main.go
@@ -5,8 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime/debug"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -15,7 +14,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
@@ -29,10 +28,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if buildVersion == "" {
|
|
||||||
buildVersion = moduleVersion()
|
|
||||||
}
|
|
||||||
|
|
||||||
l := logger
|
l := logger
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l.Formatter = &logrus.TextFormatter{
|
||||||
FullTimestamp: true,
|
FullTimestamp: true,
|
||||||
@@ -81,8 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
||||||
sshStart = nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,6 +144,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
// set up our UDP listener
|
// set up our UDP listener
|
||||||
udpConns := make([]udp.Conn, routines)
|
udpConns := make([]udp.Conn, routines)
|
||||||
port := c.GetInt("listen.port", 0)
|
port := c.GetInt("listen.port", 0)
|
||||||
|
enableGSO := c.GetBool("listen.enable_gso", true)
|
||||||
|
enableGRO := c.GetBool("listen.enable_gro", true)
|
||||||
|
gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
||||||
|
if gsoMaxSegments <= 0 {
|
||||||
|
gsoMaxSegments = defaultGSOMaxSegments
|
||||||
|
}
|
||||||
|
if gsoMaxSegments > maxKernelGSOSegments {
|
||||||
|
gsoMaxSegments = maxKernelGSOSegments
|
||||||
|
}
|
||||||
|
gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
|
||||||
|
if gsoFlushTimeout < 0 {
|
||||||
|
gsoFlushTimeout = 0
|
||||||
|
}
|
||||||
|
batchQueueDepth := c.GetInt("batch.queue_depth", 0)
|
||||||
|
|
||||||
if !configTest {
|
if !configTest {
|
||||||
rawListenHost := c.GetString("listen.host", "0.0.0.0")
|
rawListenHost := c.GetString("listen.host", "0.0.0.0")
|
||||||
@@ -169,13 +177,28 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
listenHost = ips[0].Unmap()
|
listenHost = ips[0].Unmap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
useWGDefault := runtime.GOOS == "linux"
|
||||||
|
useWG := c.GetBool("listen.use_wireguard_stack", useWGDefault)
|
||||||
|
var mkListener func(*logrus.Logger, netip.Addr, int, bool, int, int) (udp.Conn, error)
|
||||||
|
if useWG {
|
||||||
|
mkListener = udp.NewWireguardListener
|
||||||
|
} else {
|
||||||
|
mkListener = udp.NewListener
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64), i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
}
|
}
|
||||||
|
//todo set bpf on zeroth socket
|
||||||
udpServer.ReloadConfig(c)
|
udpServer.ReloadConfig(c)
|
||||||
|
if cfg, ok := udpServer.(interface {
|
||||||
|
ConfigureOffload(bool, bool, int)
|
||||||
|
}); ok {
|
||||||
|
cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
|
||||||
|
}
|
||||||
udpConns[i] = udpServer
|
udpConns[i] = udpServer
|
||||||
|
|
||||||
// If port is dynamic, discover it before the next pass through the for loop
|
// If port is dynamic, discover it before the next pass through the for loop
|
||||||
@@ -243,12 +266,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
||||||
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
||||||
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
||||||
|
EnableGSO: enableGSO,
|
||||||
|
EnableGRO: enableGRO,
|
||||||
|
GSOMaxSegments: gsoMaxSegments,
|
||||||
routines: routines,
|
routines: routines,
|
||||||
MessageMetrics: messageMetrics,
|
MessageMetrics: messageMetrics,
|
||||||
version: buildVersion,
|
version: buildVersion,
|
||||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
|
BatchFlushInterval: gsoFlushTimeout,
|
||||||
|
BatchQueueDepth: batchQueueDepth,
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,6 +288,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifce.writers = udpConns
|
ifce.writers = udpConns
|
||||||
|
ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
|
||||||
lightHouse.ifce = ifce
|
lightHouse.ifce = ifce
|
||||||
|
|
||||||
ifce.RegisterConfigChangeCallbacks(c)
|
ifce.RegisterConfigChangeCallbacks(c)
|
||||||
@@ -302,18 +331,3 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
connManager.Start,
|
connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func moduleVersion() string {
|
|
||||||
info, ok := debug.ReadBuildInfo()
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dep := range info.Deps {
|
|
||||||
if dep.Path == "github.com/slackhq/nebula" {
|
|
||||||
return strings.TrimPrefix(dep.Version, "v")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|||||||
47
outside.go
47
outside.go
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/overlay"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,7 +20,7 @@ const (
|
|||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) {
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
@@ -61,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
switch h.Subtype {
|
switch h.Subtype {
|
||||||
case header.MessageNone:
|
case header.MessageNone:
|
||||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
|
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, ip, h.RemoteIndex) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case header.MessageRelay:
|
case header.MessageRelay:
|
||||||
@@ -465,23 +466,45 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
|
||||||
var err error
|
var (
|
||||||
|
err error
|
||||||
|
pkt *overlay.Packet
|
||||||
|
)
|
||||||
|
|
||||||
|
if f.batches.tunQueue(q) != nil {
|
||||||
|
pkt = f.batches.newPacket()
|
||||||
|
if pkt != nil {
|
||||||
|
out = pkt.Payload()[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
|
if addr.IsValid() {
|
||||||
|
f.maybeSendRecvError(addr, recvIndex)
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(out, true, fwPacket)
|
err = newPacket(out, true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||||
Warnf("Error while validating inbound packet")
|
Warnf("Error while validating inbound packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping out of window packet")
|
Debugln("dropping out of window packet")
|
||||||
return false
|
return false
|
||||||
@@ -489,6 +512,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||||
@@ -501,8 +527,17 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
_, err = f.readers[q].Write(out)
|
if pkt != nil {
|
||||||
if err != nil {
|
pkt.Len = len(out)
|
||||||
|
if f.batches.enqueueTun(q, pkt) {
|
||||||
|
f.observeTunQueueLen(q)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
f.writePacketToTun(q, pkt)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = f.readers[q].Write(out); err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -13,6 +14,86 @@ type Device interface {
|
|||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
SupportsMultiqueue() bool
|
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Packet represents a single packet buffer with optional headroom to carry
|
||||||
|
// metadata (for example virtio-net headers).
|
||||||
|
type Packet struct {
|
||||||
|
Buf []byte
|
||||||
|
Offset int
|
||||||
|
Len int
|
||||||
|
release func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Payload() []byte {
|
||||||
|
return p.Buf[p.Offset : p.Offset+p.Len]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Reset() {
|
||||||
|
p.Len = 0
|
||||||
|
p.Offset = 0
|
||||||
|
p.release = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Release() {
|
||||||
|
if p.release != nil {
|
||||||
|
p.release()
|
||||||
|
p.release = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Capacity() int {
|
||||||
|
return len(p.Buf) - p.Offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// PacketPool manages reusable buffers with headroom.
|
||||||
|
type PacketPool struct {
|
||||||
|
headroom int
|
||||||
|
blksz int
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPacketPool(headroom, payload int) *PacketPool {
|
||||||
|
p := &PacketPool{headroom: headroom, blksz: headroom + payload}
|
||||||
|
p.pool.New = func() any {
|
||||||
|
buf := make([]byte, p.blksz)
|
||||||
|
return &Packet{Buf: buf, Offset: headroom}
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketPool) Get() *Packet {
|
||||||
|
pkt := p.pool.Get().(*Packet)
|
||||||
|
pkt.Offset = p.headroom
|
||||||
|
pkt.Len = 0
|
||||||
|
pkt.release = func() { p.put(pkt) }
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketPool) put(pkt *Packet) {
|
||||||
|
pkt.Reset()
|
||||||
|
p.pool.Put(pkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchReader allows reading multiple packets into a shared pool with
|
||||||
|
// preallocated headroom (e.g. virtio-net headers).
|
||||||
|
type BatchReader interface {
|
||||||
|
ReadIntoBatch(pool *PacketPool) ([]*Packet, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchWriter writes a slice of packets that carry their own metadata.
|
||||||
|
type BatchWriter interface {
|
||||||
|
WriteBatch(packets []*Packet) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchCapableDevice describes a device that can efficiently read and write
|
||||||
|
// batches of packets with virtio headroom.
|
||||||
|
type BatchCapableDevice interface {
|
||||||
|
Device
|
||||||
|
BatchReader
|
||||||
|
BatchWriter
|
||||||
|
BatchHeadroom() int
|
||||||
|
BatchPayloadCap() int
|
||||||
|
BatchSize() int
|
||||||
|
}
|
||||||
|
|||||||
@@ -95,10 +95,6 @@ func (t *tun) Name() string {
|
|||||||
return "android"
|
return "android"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -549,10 +549,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,10 +105,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) SupportsMultiqueue() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -450,10 +450,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -151,10 +151,6 @@ func (t *tun) Name() string {
|
|||||||
return "iOS"
|
return "iOS"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -19,6 +20,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
|
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
@@ -33,6 +35,7 @@ type tun struct {
|
|||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
deviceIndex int
|
deviceIndex int
|
||||||
ioctlFd uintptr
|
ioctlFd uintptr
|
||||||
|
wgDevice wgtun.Device
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
@@ -68,7 +71,9 @@ type ifreqQLEN struct {
|
|||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
useWGDefault := runtime.GOOS == "linux"
|
||||||
|
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
|
||||||
|
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -113,7 +118,9 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
useWGDefault := runtime.GOOS == "linux"
|
||||||
|
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
|
||||||
|
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -123,16 +130,45 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) {
|
||||||
|
var (
|
||||||
|
rw io.ReadWriteCloser = file
|
||||||
|
fd = int(file.Fd())
|
||||||
|
wgDev wgtun.Device
|
||||||
|
)
|
||||||
|
|
||||||
|
if useWireguard {
|
||||||
|
dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err)
|
||||||
|
}
|
||||||
|
wgDev = dev
|
||||||
|
rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU))
|
||||||
|
fd = int(dev.File().Fd())
|
||||||
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: rw,
|
||||||
fd: int(file.Fd()),
|
fd: fd,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
if wgDev != nil {
|
||||||
|
t.wgDevice = wgDev
|
||||||
|
}
|
||||||
|
if wgDev != nil {
|
||||||
|
// replace ioctl fd with device file descriptor to keep route management working
|
||||||
|
file = wgDev.File()
|
||||||
|
t.fd = int(file.Fd())
|
||||||
|
t.ioctlFd = file.Fd()
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.ioctlFd == 0 {
|
||||||
|
t.ioctlFd = file.Fd()
|
||||||
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -216,10 +252,6 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -586,42 +618,48 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||||
|
|
||||||
var gateways routing.Gateways
|
var gateways routing.Gateways
|
||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
link, err := netlink.LinkByName(t.Device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
|
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
|
||||||
return gateways
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
if r.LinkIndex == link.Attrs().Index {
|
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
||||||
gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
|
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||||
if ok {
|
if !ok {
|
||||||
if t.isGatewayInVpnNetworks(gwAddr) {
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
|
||||||
} else {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range r.MultiPath {
|
for _, p := range r.MultiPath {
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
if p.LinkIndex == link.Attrs().Index {
|
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
||||||
gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
|
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
||||||
if ok {
|
if !ok {
|
||||||
if t.isGatewayInVpnNetworks(gwAddr) {
|
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
|
||||||
} else {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
// p.Hops+1 = weight of the route
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -630,27 +668,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
return gateways
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
|
|
||||||
// Try to use the old RTA_GATEWAY first
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(gw)
|
|
||||||
if !ok {
|
|
||||||
// Fallback to the new RTA_VIA
|
|
||||||
rVia, ok := via.(*netlink.Via)
|
|
||||||
if ok {
|
|
||||||
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if gwAddr.IsValid() {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
return gwAddr, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return netip.Addr{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
|
|
||||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||||
|
|
||||||
if len(gateways) == 0 {
|
if len(gateways) == 0 {
|
||||||
// No gateways relevant to our network, no routing changes required.
|
// No gateways relevant to our network, no routing changes required.
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||||
@@ -693,6 +714,14 @@ func (t *tun) Close() error {
|
|||||||
_ = t.ReadWriteCloser.Close()
|
_ = t.ReadWriteCloser.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t.wgDevice != nil {
|
||||||
|
_ = t.wgDevice.Close()
|
||||||
|
if t.ioctlFd > 0 {
|
||||||
|
// underlying fd already closed by the device
|
||||||
|
t.ioctlFd = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if t.ioctlFd > 0 {
|
if t.ioctlFd > 0 {
|
||||||
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
||||||
}
|
}
|
||||||
|
|||||||
56
overlay/tun_linux_batch.go
Normal file
56
overlay/tun_linux_batch.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build linux && !android && !e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func (t *tun) batchIO() (*wireguardTunIO, bool) {
|
||||||
|
io, ok := t.ReadWriteCloser.(*wireguardTunIO)
|
||||||
|
return io, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
||||||
|
io, ok := t.batchIO()
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("wireguard batch I/O not enabled")
|
||||||
|
}
|
||||||
|
return io.ReadIntoBatch(pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) WriteBatch(packets []*Packet) (int, error) {
|
||||||
|
io, ok := t.batchIO()
|
||||||
|
if ok {
|
||||||
|
return io.WriteBatch(packets)
|
||||||
|
}
|
||||||
|
for _, pkt := range packets {
|
||||||
|
if pkt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
return len(packets), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchHeadroom() int {
|
||||||
|
if io, ok := t.batchIO(); ok {
|
||||||
|
return io.BatchHeadroom()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchPayloadCap() int {
|
||||||
|
if io, ok := t.batchIO(); ok {
|
||||||
|
return io.BatchPayloadCap()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
if io, ok := t.batchIO(); ok {
|
||||||
|
return io.BatchSize()
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
@@ -390,10 +390,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -310,10 +310,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,10 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -234,10 +234,6 @@ func (t *winTun) Write(b []byte) (int, error) {
|
|||||||
return t.tun.Write(b, 0)
|
return t.tun.Write(b, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,10 +46,6 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) SupportsMultiqueue() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|||||||
220
overlay/wireguard_tun_linux.go
Normal file
220
overlay/wireguard_tun_linux.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
//go:build linux && !android && !e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type wireguardTunIO struct {
|
||||||
|
dev wgtun.Device
|
||||||
|
mtu int
|
||||||
|
batchSize int
|
||||||
|
|
||||||
|
readMu sync.Mutex
|
||||||
|
readBuffers [][]byte
|
||||||
|
readLens []int
|
||||||
|
legacyBuf []byte
|
||||||
|
|
||||||
|
writeMu sync.Mutex
|
||||||
|
writeBuf []byte
|
||||||
|
writeWrap [][]byte
|
||||||
|
writeBuffers [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
||||||
|
batch := dev.BatchSize()
|
||||||
|
if batch <= 0 {
|
||||||
|
batch = 1
|
||||||
|
}
|
||||||
|
if mtu <= 0 {
|
||||||
|
mtu = DefaultMTU
|
||||||
|
}
|
||||||
|
return &wireguardTunIO{
|
||||||
|
dev: dev,
|
||||||
|
mtu: mtu,
|
||||||
|
batchSize: batch,
|
||||||
|
readLens: make([]int, batch),
|
||||||
|
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||||
|
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||||
|
writeWrap: make([][]byte, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) Read(p []byte) (int, error) {
|
||||||
|
w.readMu.Lock()
|
||||||
|
defer w.readMu.Unlock()
|
||||||
|
|
||||||
|
bufs := w.readBuffers
|
||||||
|
if len(bufs) == 0 {
|
||||||
|
bufs = [][]byte{w.legacyBuf}
|
||||||
|
w.readBuffers = bufs
|
||||||
|
}
|
||||||
|
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
length := w.readLens[0]
|
||||||
|
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
|
||||||
|
return length, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
||||||
|
if len(p) > w.mtu {
|
||||||
|
return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
|
||||||
|
}
|
||||||
|
w.writeMu.Lock()
|
||||||
|
defer w.writeMu.Unlock()
|
||||||
|
buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
|
||||||
|
for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
|
||||||
|
buf[i] = 0
|
||||||
|
}
|
||||||
|
copy(buf[wgtun.VirtioNetHdrLen:], p)
|
||||||
|
w.writeWrap[0] = buf
|
||||||
|
n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
||||||
|
if pool == nil {
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.readMu.Lock()
|
||||||
|
defer w.readMu.Unlock()
|
||||||
|
|
||||||
|
if len(w.readBuffers) < w.batchSize {
|
||||||
|
w.readBuffers = make([][]byte, w.batchSize)
|
||||||
|
}
|
||||||
|
if len(w.readLens) < w.batchSize {
|
||||||
|
w.readLens = make([]int, w.batchSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
packets := make([]*Packet, w.batchSize)
|
||||||
|
requiredHeadroom := w.BatchHeadroom()
|
||||||
|
requiredPayload := w.BatchPayloadCap()
|
||||||
|
headroom := 0
|
||||||
|
for i := 0; i < w.batchSize; i++ {
|
||||||
|
pkt := pool.Get()
|
||||||
|
if pkt == nil {
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
|
||||||
|
}
|
||||||
|
if pkt.Capacity() < requiredPayload {
|
||||||
|
pkt.Release()
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
headroom = pkt.Offset
|
||||||
|
if headroom < requiredHeadroom {
|
||||||
|
pkt.Release()
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
|
||||||
|
}
|
||||||
|
} else if pkt.Offset != headroom {
|
||||||
|
pkt.Release()
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
|
||||||
|
}
|
||||||
|
packets[i] = pkt
|
||||||
|
w.readBuffers[i] = pkt.Buf
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
|
||||||
|
if err != nil {
|
||||||
|
releasePackets(packets)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
releasePackets(packets)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
packets[i].Len = w.readLens[i]
|
||||||
|
}
|
||||||
|
for i := n; i < w.batchSize; i++ {
|
||||||
|
packets[i].Release()
|
||||||
|
packets[i] = nil
|
||||||
|
}
|
||||||
|
return packets[:n], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
|
||||||
|
if len(packets) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
requiredHeadroom := w.BatchHeadroom()
|
||||||
|
offset := packets[0].Offset
|
||||||
|
if offset < requiredHeadroom {
|
||||||
|
releasePackets(packets)
|
||||||
|
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
|
||||||
|
}
|
||||||
|
for _, pkt := range packets {
|
||||||
|
if pkt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pkt.Offset != offset {
|
||||||
|
releasePackets(packets)
|
||||||
|
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
|
||||||
|
}
|
||||||
|
limit := pkt.Offset + pkt.Len
|
||||||
|
if limit > len(pkt.Buf) {
|
||||||
|
releasePackets(packets)
|
||||||
|
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.writeMu.Lock()
|
||||||
|
defer w.writeMu.Unlock()
|
||||||
|
|
||||||
|
if len(w.writeBuffers) < len(packets) {
|
||||||
|
w.writeBuffers = make([][]byte, len(packets))
|
||||||
|
}
|
||||||
|
for i, pkt := range packets {
|
||||||
|
if pkt == nil {
|
||||||
|
w.writeBuffers[i] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
limit := pkt.Offset + pkt.Len
|
||||||
|
w.writeBuffers[i] = pkt.Buf[:limit]
|
||||||
|
}
|
||||||
|
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
releasePackets(packets)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) BatchHeadroom() int {
|
||||||
|
return wgtun.VirtioNetHdrLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) BatchPayloadCap() int {
|
||||||
|
return w.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) BatchSize() int {
|
||||||
|
return w.batchSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func releasePackets(pkts []*Packet) {
|
||||||
|
for _, pkt := range pkts {
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
10
pki.go
10
pki.go
@@ -523,13 +523,9 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
|||||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bl := c.GetStringSlice("pki.blocklist", []string{})
|
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
||||||
if len(bl) > 0 {
|
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
||||||
for _, fp := range bl {
|
caPool.BlocklistFingerprint(fp)
|
||||||
caPool.BlocklistFingerprint(fp)
|
|
||||||
}
|
|
||||||
|
|
||||||
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return caPool, nil
|
return caPool, nil
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import (
|
|||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"go.yaml.in/yaml/v3"
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|||||||
@@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, errors.New("unsupported")
|
return nil, errors.New("unsupported")
|
||||||
}
|
}
|
||||||
|
|||||||
16
udp/conn.go
16
udp/conn.go
@@ -19,10 +19,21 @@ type Conn interface {
|
|||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
SupportsMultipleReaders() bool
|
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Datagram represents a UDP payload destined to a specific address.
|
||||||
|
type Datagram struct {
|
||||||
|
Payload []byte
|
||||||
|
Addr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchConn can send multiple datagrams in one syscall.
|
||||||
|
type BatchConn interface {
|
||||||
|
Conn
|
||||||
|
WriteBatch(pkts []Datagram) error
|
||||||
|
}
|
||||||
|
|
||||||
type NoopConn struct{}
|
type NoopConn struct{}
|
||||||
|
|
||||||
func (NoopConn) Rebind() error {
|
func (NoopConn) Rebind() error {
|
||||||
@@ -34,9 +45,6 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
func (NoopConn) ListenOut(_ EncReader) {
|
func (NoopConn) ListenOut(_ EncReader) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (NoopConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
|||||||
return ErrInvalidIPv6RemoteForSocket
|
return ErrInvalidIPv6RemoteForSocket
|
||||||
}
|
}
|
||||||
|
|
||||||
var rsa unix.RawSockaddrInet4
|
var rsa unix.RawSockaddrInet6
|
||||||
rsa.Family = unix.AF_INET
|
rsa.Family = unix.AF_INET6
|
||||||
rsa.Addr = ap.Addr().As4()
|
rsa.Addr = ap.Addr().As16()
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
||||||
sa = unsafe.Pointer(&rsa)
|
sa = unsafe.Pointer(&rsa)
|
||||||
addrLen = syscall.SizeofSockaddrInet4
|
addrLen = syscall.SizeofSockaddrInet4
|
||||||
@@ -184,10 +184,6 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
var err error
|
var err error
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
|
|||||||
@@ -85,7 +85,3 @@ func (u *GenericConn) ListenOut(r EncReader) {
|
|||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
|
|||||||
return ip, false
|
return ip, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
|
||||||
af := unix.AF_INET6
|
af := unix.AF_INET6
|
||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
af = unix.AF_INET
|
af = unix.AF_INET
|
||||||
@@ -72,10 +72,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) SupportsMultipleReaders() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -314,31 +310,51 @@ func (u *StdConn) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||||
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
if len(udpConns) == 0 {
|
||||||
var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
type statsProvider struct {
|
||||||
|
index int
|
||||||
|
conn *StdConn
|
||||||
|
}
|
||||||
|
|
||||||
|
providers := make([]statsProvider, 0, len(udpConns))
|
||||||
|
for i, c := range udpConns {
|
||||||
|
if sc, ok := c.(*StdConn); ok {
|
||||||
|
providers = append(providers, statsProvider{index: i, conn: sc})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
var meminfo [unix.SK_MEMINFO_VARS]uint32
|
var meminfo [unix.SK_MEMINFO_VARS]uint32
|
||||||
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
|
if err := providers[0].conn.getMemInfo(&meminfo); err != nil {
|
||||||
udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
|
return func() {}
|
||||||
for i := range udpConns {
|
}
|
||||||
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
|
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
|
udpGauges := make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(providers))
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
|
for i, provider := range providers {
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
|
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", provider.index), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", provider.index), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", provider.index), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", provider.index), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", provider.index), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", provider.index), nil),
|
||||||
}
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", provider.index), nil),
|
||||||
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", provider.index), nil),
|
||||||
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", provider.index), nil),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return func() {
|
return func() {
|
||||||
for i, gauges := range udpGauges {
|
for i, provider := range providers {
|
||||||
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
|
if err := provider.conn.getMemInfo(&meminfo); err == nil {
|
||||||
for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
|
for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
|
||||||
gauges[j].Update(int64(meminfo[j]))
|
udpGauges[i][j].Update(int64(meminfo[j]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -315,10 +315,6 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *RIOConn) Rebind() error {
|
func (u *RIOConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -127,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
return u.Addr, nil
|
return u.Addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *TesterConn) Rebind() error {
|
func (u *TesterConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
226
udp/wireguard_conn_linux.go
Normal file
226
udp/wireguard_conn_linux.go
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
//go:build linux && !android && !e2e_testing
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
|
||||||
|
type WGConn struct {
|
||||||
|
l *logrus.Logger
|
||||||
|
bind *wgconn.StdNetBind
|
||||||
|
recvers []wgconn.ReceiveFunc
|
||||||
|
batch int
|
||||||
|
reqBatch int
|
||||||
|
localIP netip.Addr
|
||||||
|
localPort uint16
|
||||||
|
enableGSO bool
|
||||||
|
enableGRO bool
|
||||||
|
gsoMaxSeg int
|
||||||
|
closed atomic.Bool
|
||||||
|
q int
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
|
||||||
|
func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
|
||||||
|
bind := wgconn.NewStdNetBindForAddr(ip, multi, q)
|
||||||
|
recvers, actualPort, err := bind.Open(uint16(port))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if batch <= 0 {
|
||||||
|
batch = bind.BatchSize()
|
||||||
|
} else if batch > bind.BatchSize() {
|
||||||
|
batch = bind.BatchSize()
|
||||||
|
}
|
||||||
|
return &WGConn{
|
||||||
|
l: l,
|
||||||
|
bind: bind,
|
||||||
|
recvers: recvers,
|
||||||
|
batch: batch,
|
||||||
|
reqBatch: batch,
|
||||||
|
localIP: ip,
|
||||||
|
localPort: actualPort,
|
||||||
|
q: q,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) Rebind() error {
|
||||||
|
// WireGuard's bind does not support rebinding in place.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
|
if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
|
||||||
|
// Fallback to wildcard IPv4 for display purposes.
|
||||||
|
return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(c.localIP, c.localPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
|
||||||
|
batchSize := c.batch
|
||||||
|
packets := make([][]byte, batchSize)
|
||||||
|
for i := range packets {
|
||||||
|
packets[i] = make([]byte, 0xffff)
|
||||||
|
}
|
||||||
|
sizes := make([]int, batchSize)
|
||||||
|
endpoints := make([]wgconn.Endpoint, batchSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n, err := fn(packets, sizes, endpoints)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.l != nil {
|
||||||
|
c.l.WithError(err).Debug("wireguard UDP listener receive error")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
|
||||||
|
if !ok {
|
||||||
|
if c.l != nil {
|
||||||
|
c.l.Warn("wireguard UDP listener received unexpected endpoint type")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addr := stdEp.AddrPort
|
||||||
|
r(addr, packets[i][:sizes[i]])
|
||||||
|
endpoints[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) ListenOut(r EncReader) {
|
||||||
|
for _, fn := range c.recvers {
|
||||||
|
go c.listen(fn, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.closed.Load() {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
ep := &wgconn.StdNetEndpoint{AddrPort: addr}
|
||||||
|
return c.bind.Send([][]byte{b}, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) WriteBatch(datagrams []Datagram) error {
|
||||||
|
if len(datagrams) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.closed.Load() {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
max := c.batch
|
||||||
|
if max <= 0 {
|
||||||
|
max = len(datagrams)
|
||||||
|
if max == 0 {
|
||||||
|
max = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bufs := make([][]byte, 0, max)
|
||||||
|
var (
|
||||||
|
current netip.AddrPort
|
||||||
|
endpoint *wgconn.StdNetEndpoint
|
||||||
|
haveAddr bool
|
||||||
|
)
|
||||||
|
flush := func() error {
|
||||||
|
if len(bufs) == 0 || endpoint == nil {
|
||||||
|
bufs = bufs[:0]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := c.bind.Send(bufs, endpoint)
|
||||||
|
bufs = bufs[:0]
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, d := range datagrams {
|
||||||
|
if len(d.Payload) == 0 || !d.Addr.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !haveAddr || d.Addr != current {
|
||||||
|
if err := flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
current = d.Addr
|
||||||
|
endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
|
||||||
|
haveAddr = true
|
||||||
|
}
|
||||||
|
bufs = append(bufs, d.Payload)
|
||||||
|
if len(bufs) >= max {
|
||||||
|
if err := flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
|
||||||
|
c.enableGSO = enableGSO
|
||||||
|
c.enableGRO = enableGRO
|
||||||
|
if maxSegments <= 0 {
|
||||||
|
maxSegments = 1
|
||||||
|
} else if maxSegments > wgconn.IdealBatchSize {
|
||||||
|
maxSegments = wgconn.IdealBatchSize
|
||||||
|
}
|
||||||
|
c.gsoMaxSeg = maxSegments
|
||||||
|
|
||||||
|
effectiveBatch := c.reqBatch
|
||||||
|
if enableGSO && c.bind != nil {
|
||||||
|
bindBatch := c.bind.BatchSize()
|
||||||
|
if effectiveBatch < bindBatch {
|
||||||
|
if c.l != nil {
|
||||||
|
c.l.WithFields(logrus.Fields{
|
||||||
|
"requested": c.reqBatch,
|
||||||
|
"effective": bindBatch,
|
||||||
|
}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
|
||||||
|
}
|
||||||
|
effectiveBatch = bindBatch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.batch = effectiveBatch
|
||||||
|
|
||||||
|
if c.l != nil {
|
||||||
|
c.l.WithFields(logrus.Fields{
|
||||||
|
"enableGSO": enableGSO,
|
||||||
|
"enableGRO": enableGRO,
|
||||||
|
"gsoMaxSegments": maxSegments,
|
||||||
|
}).Debug("configured wireguard UDP offload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) ReloadConfig(*config.C) {
|
||||||
|
// WireGuard bind currently does not expose runtime configuration knobs.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGConn) Close() error {
|
||||||
|
var err error
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.closed.Store(true)
|
||||||
|
err = c.bind.Close()
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
15
udp/wireguard_conn_unsupported.go
Normal file
15
udp/wireguard_conn_unsupported.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !linux || android || e2e_testing
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWireguardListener is only available on Linux builds.
|
||||||
|
func NewWireguardListener(*logrus.Logger, netip.Addr, int, bool, int) (Conn, error) {
|
||||||
|
return nil, fmt.Errorf("wireguard experimental UDP listener is only supported on Linux")
|
||||||
|
}
|
||||||
587
wgstack/conn/bind_std.go
Normal file
587
wgstack/conn/bind_std.go
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Bind = (*StdNetBind)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||||
|
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||||
|
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||||
|
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||||
|
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||||
|
type StdNetBind struct {
|
||||||
|
mu sync.Mutex // protects all fields except as specified
|
||||||
|
ipv4 *net.UDPConn
|
||||||
|
ipv6 *net.UDPConn
|
||||||
|
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||||
|
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||||
|
ipv4TxOffload bool
|
||||||
|
ipv4RxOffload bool
|
||||||
|
ipv6TxOffload bool
|
||||||
|
ipv6RxOffload bool
|
||||||
|
|
||||||
|
// these two fields are not guarded by mu
|
||||||
|
udpAddrPool sync.Pool
|
||||||
|
msgsPool sync.Pool
|
||||||
|
|
||||||
|
blackhole4 bool
|
||||||
|
blackhole6 bool
|
||||||
|
q int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStdNetBind creates a bind that listens on all interfaces.
|
||||||
|
func NewStdNetBind() *StdNetBind {
|
||||||
|
return newStdNetBind().(*StdNetBind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStdNetBindForAddr creates a bind that listens on a specific address.
|
||||||
|
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
|
||||||
|
// IPv6 socket will be created.
|
||||||
|
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool, q int) *StdNetBind {
|
||||||
|
b := NewStdNetBind()
|
||||||
|
b.q = q
|
||||||
|
//if addr.IsValid() {
|
||||||
|
// if addr.IsUnspecified() {
|
||||||
|
// // keep dual-stack defaults with empty listen addresses
|
||||||
|
// } else if addr.Is4() {
|
||||||
|
// b.listenAddr4 = addr.Unmap().String()
|
||||||
|
// b.bindV4 = true
|
||||||
|
// b.bindV6 = false
|
||||||
|
// } else {
|
||||||
|
// b.listenAddr6 = addr.Unmap().String()
|
||||||
|
// b.bindV6 = true
|
||||||
|
// b.bindV4 = false
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
//b.reusePort = reusePort
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStdNetBind() Bind {
|
||||||
|
return &StdNetBind{
|
||||||
|
udpAddrPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: make([]byte, 16),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
msgsPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
// ipv6.Message and ipv4.Message are interchangeable as they are
|
||||||
|
// both aliases for x/net/internal/socket.Message.
|
||||||
|
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
||||||
|
}
|
||||||
|
return &msgs
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type StdNetEndpoint struct {
|
||||||
|
// AddrPort is the endpoint destination.
|
||||||
|
netip.AddrPort
|
||||||
|
// src is the current sticky source address and interface index, if
|
||||||
|
// supported. Typically this is a PKTINFO structure from/for control
|
||||||
|
// messages, see unix.PKTINFO for an example.
|
||||||
|
src []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Bind = (*StdNetBind)(nil)
|
||||||
|
_ Endpoint = &StdNetEndpoint{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
|
e, err := netip.ParseAddrPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &StdNetEndpoint{
|
||||||
|
AddrPort: e,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) ClearSrc() {
|
||||||
|
if e.src != nil {
|
||||||
|
// Truncate src, no need to reallocate.
|
||||||
|
e.src = e.src[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||||
|
return e.AddrPort.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||||
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstToString() string {
|
||||||
|
return e.AddrPort.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenNet(network string, port int, q int) (*net.UDPConn, int, error) {
|
||||||
|
lc := listenConfig(q)
|
||||||
|
|
||||||
|
conn, err := lc.ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if q == 0 {
|
||||||
|
if EvilFdZero == 0 {
|
||||||
|
panic("fuck")
|
||||||
|
}
|
||||||
|
err = reusePortHax(EvilFdZero)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("reuse port hax: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve port.
|
||||||
|
laddr := conn.LocalAddr()
|
||||||
|
uaddr, err := net.ResolveUDPAddr(
|
||||||
|
laddr.Network(),
|
||||||
|
laddr.String(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var tries int
|
||||||
|
|
||||||
|
if s.ipv4 != nil || s.ipv6 != nil {
|
||||||
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
||||||
|
// If uport is 0, we can retry on failure.
|
||||||
|
again:
|
||||||
|
port := int(uport)
|
||||||
|
var v4conn, v6conn *net.UDPConn
|
||||||
|
var v4pc *ipv4.PacketConn
|
||||||
|
var v6pc *ipv6.PacketConn
|
||||||
|
|
||||||
|
v4conn, port, err = listenNet("udp4", port, s.q)
|
||||||
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen on the same port as we're using for ipv4.
|
||||||
|
v6conn, port, err = listenNet("udp6", port, s.q)
|
||||||
|
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||||
|
v4conn.Close()
|
||||||
|
tries++
|
||||||
|
goto again
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
v4conn.Close()
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
var fns []ReceiveFunc
|
||||||
|
if v4conn != nil {
|
||||||
|
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
v4pc = ipv4.NewPacketConn(v4conn)
|
||||||
|
s.ipv4PC = v4pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
||||||
|
s.ipv4 = v4conn
|
||||||
|
}
|
||||||
|
if v6conn != nil {
|
||||||
|
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
v6pc = ipv6.NewPacketConn(v6conn)
|
||||||
|
s.ipv6PC = v6pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
||||||
|
s.ipv6 = v6conn
|
||||||
|
}
|
||||||
|
if len(fns) == 0 {
|
||||||
|
return nil, 0, syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
|
||||||
|
return fns, uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
||||||
|
for i := range *msgs {
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||||
|
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||||
|
}
|
||||||
|
s.msgsPool.Put(msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
||||||
|
return s.msgsPool.Get().(*[]ipv6.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// If compilation fails here these are no longer the same underlying type.
|
||||||
|
_ ipv6.Message = ipv4.Message{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type batchReader interface {
|
||||||
|
ReadBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type batchWriter interface {
|
||||||
|
WriteBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) receiveIP(
|
||||||
|
br batchReader,
|
||||||
|
conn *net.UDPConn,
|
||||||
|
rxOffload bool,
|
||||||
|
bufs [][]byte,
|
||||||
|
sizes []int,
|
||||||
|
eps []Endpoint,
|
||||||
|
) (n int, err error) {
|
||||||
|
msgs := s.getMessages()
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
|
}
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
var numMsgs int
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
if rxOffload {
|
||||||
|
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
||||||
|
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
numMsgs, err = br.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg := &(*msgs)[0]
|
||||||
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
sizes[i] = msg.N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
|
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||||
|
// rename the IdealBatchSize constant to BatchSize.
|
||||||
|
func (s *StdNetBind) BatchSize() int {
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
return IdealBatchSize
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var err1, err2 error
|
||||||
|
if s.ipv4 != nil {
|
||||||
|
err1 = s.ipv4.Close()
|
||||||
|
s.ipv4 = nil
|
||||||
|
s.ipv4PC = nil
|
||||||
|
}
|
||||||
|
if s.ipv6 != nil {
|
||||||
|
err2 = s.ipv6.Close()
|
||||||
|
s.ipv6 = nil
|
||||||
|
s.ipv6PC = nil
|
||||||
|
}
|
||||||
|
s.blackhole4 = false
|
||||||
|
s.blackhole6 = false
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
s.ipv4RxOffload = false
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
s.ipv6RxOffload = false
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrUDPGSODisabled struct {
|
||||||
|
onLaddr string
|
||||||
|
RetryErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrUDPGSODisabled) Error() string {
|
||||||
|
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrUDPGSODisabled) Unwrap() error {
|
||||||
|
return e.RetryErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
blackhole := s.blackhole4
|
||||||
|
conn := s.ipv4
|
||||||
|
offload := s.ipv4TxOffload
|
||||||
|
br := batchWriter(s.ipv4PC)
|
||||||
|
is6 := false
|
||||||
|
if endpoint.DstIP().Is6() {
|
||||||
|
blackhole = s.blackhole6
|
||||||
|
conn = s.ipv6
|
||||||
|
br = s.ipv6PC
|
||||||
|
is6 = true
|
||||||
|
offload = s.ipv6TxOffload
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if blackhole {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := s.getMessages()
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
|
defer s.udpAddrPool.Put(ua)
|
||||||
|
if is6 {
|
||||||
|
as16 := endpoint.DstIP().As16()
|
||||||
|
copy(ua.IP, as16[:])
|
||||||
|
ua.IP = ua.IP[:16]
|
||||||
|
} else {
|
||||||
|
as4 := endpoint.DstIP().As4()
|
||||||
|
copy(ua.IP, as4[:])
|
||||||
|
ua.IP = ua.IP[:4]
|
||||||
|
}
|
||||||
|
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
||||||
|
var (
|
||||||
|
retried bool
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
retry:
|
||||||
|
if offload {
|
||||||
|
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
||||||
|
err = s.send(conn, br, (*msgs)[:n])
|
||||||
|
if err != nil && offload && errShouldDisableUDPGSO(err) {
|
||||||
|
offload = false
|
||||||
|
s.mu.Lock()
|
||||||
|
if is6 {
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
} else {
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Addr = ua
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
||||||
|
}
|
||||||
|
err = s.send(conn, br, (*msgs)[:len(bufs)])
|
||||||
|
}
|
||||||
|
if retried {
|
||||||
|
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
start int
|
||||||
|
)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
for {
|
||||||
|
n, err = pc.WriteBatch(msgs[start:], 0)
|
||||||
|
if err != nil || n == len(msgs[start:]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start += n
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, msg := range msgs {
|
||||||
|
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Exceeding these values results in EMSGSIZE. They account for layer3 and
|
||||||
|
// layer4 headers. IPv6 does not need to account for itself as the payload
|
||||||
|
// length field is self excluding.
|
||||||
|
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
||||||
|
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
||||||
|
|
||||||
|
// This is a hard limit imposed by the kernel.
|
||||||
|
udpSegmentMaxDatagrams = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
type setGSOFunc func(control *[]byte, gsoSize uint16)
|
||||||
|
|
||||||
|
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
||||||
|
var (
|
||||||
|
base = -1 // index of msg we are currently coalescing into
|
||||||
|
gsoSize int // segmentation size of msgs[base]
|
||||||
|
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||||
|
endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
||||||
|
)
|
||||||
|
maxPayloadLen := maxIPv4PayloadLen
|
||||||
|
if ep.DstIP().Is6() {
|
||||||
|
maxPayloadLen = maxIPv6PayloadLen
|
||||||
|
}
|
||||||
|
for i, buf := range bufs {
|
||||||
|
if i > 0 {
|
||||||
|
msgLen := len(buf)
|
||||||
|
baseLenBefore := len(msgs[base].Buffers[0])
|
||||||
|
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||||
|
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||||
|
msgLen <= gsoSize &&
|
||||||
|
msgLen <= freeBaseCap &&
|
||||||
|
dgramCnt < udpSegmentMaxDatagrams &&
|
||||||
|
!endBatch {
|
||||||
|
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
||||||
|
if i == len(bufs)-1 {
|
||||||
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
|
}
|
||||||
|
dgramCnt++
|
||||||
|
if msgLen < gsoSize {
|
||||||
|
// A smaller than gsoSize packet on the tail is legal, but
|
||||||
|
// it must end the batch.
|
||||||
|
endBatch = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dgramCnt > 1 {
|
||||||
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
|
}
|
||||||
|
// Reset prior to incrementing base since we are preparing to start a
|
||||||
|
// new potential batch.
|
||||||
|
endBatch = false
|
||||||
|
base++
|
||||||
|
gsoSize = len(buf)
|
||||||
|
setSrcControl(&msgs[base].OOB, ep)
|
||||||
|
msgs[base].Buffers[0] = buf
|
||||||
|
msgs[base].Addr = addr
|
||||||
|
dgramCnt = 1
|
||||||
|
}
|
||||||
|
return base + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
type getGSOFunc func(control []byte) (int, error)
|
||||||
|
|
||||||
|
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
||||||
|
for i := firstMsgAt; i < len(msgs); i++ {
|
||||||
|
msg := &msgs[i]
|
||||||
|
if msg.N == 0 {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
gsoSize int
|
||||||
|
start int
|
||||||
|
end = msg.N
|
||||||
|
numToSplit = 1
|
||||||
|
)
|
||||||
|
gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if gsoSize > 0 {
|
||||||
|
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||||
|
end = gsoSize
|
||||||
|
}
|
||||||
|
for j := 0; j < numToSplit; j++ {
|
||||||
|
if n > i {
|
||||||
|
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||||
|
}
|
||||||
|
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||||
|
msgs[n].N = copied
|
||||||
|
msgs[n].Addr = msg.Addr
|
||||||
|
start = end
|
||||||
|
end += gsoSize
|
||||||
|
if end > msg.N {
|
||||||
|
end = msg.N
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
if i != n-1 {
|
||||||
|
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||||
|
// of splitting, so we only zero the source msg len when it is not
|
||||||
|
// the destination of the last split operation above.
|
||||||
|
msg.N = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
131
wgstack/conn/conn.go
Normal file
131
wgstack/conn/conn.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
||||||
|
)
|
||||||
|
|
||||||
|
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||||
|
// into packets. On a successful read it returns the number of elements of
|
||||||
|
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||||
|
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||||
|
// and eps slice with a length greater than or equal to the length of packets.
|
||||||
|
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||||
|
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||||
|
|
||||||
|
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||||
|
//
|
||||||
|
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
||||||
|
// depending on the platform-specific implementation.
|
||||||
|
type Bind interface {
|
||||||
|
// Open puts the Bind into a listening state on a given port and reports the actual
|
||||||
|
// port that it bound to. Passing zero results in a random selection.
|
||||||
|
// fns is the set of functions that will be called to receive packets.
|
||||||
|
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
|
||||||
|
|
||||||
|
// Close closes the Bind listener.
|
||||||
|
// All fns returned by Open must return net.ErrClosed after a call to Close.
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// SetMark sets the mark for each packet sent through this Bind.
|
||||||
|
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||||
|
SetMark(mark uint32) error
|
||||||
|
|
||||||
|
// Send writes one or more packets in bufs to address ep. The length of
|
||||||
|
// bufs must not exceed BatchSize().
|
||||||
|
Send(bufs [][]byte, ep Endpoint) error
|
||||||
|
|
||||||
|
// ParseEndpoint creates a new endpoint from a string.
|
||||||
|
ParseEndpoint(s string) (Endpoint, error)
|
||||||
|
|
||||||
|
// BatchSize is the number of buffers expected to be passed to
|
||||||
|
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||||
|
BatchSize() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindSocketToInterface is implemented by Bind objects that support being
|
||||||
|
// tied to a single network interface. Used by wireguard-windows.
|
||||||
|
type BindSocketToInterface interface {
|
||||||
|
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
||||||
|
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
||||||
|
// file descriptor peeked at. Used by wireguard-android.
|
||||||
|
type PeekLookAtSocketFd interface {
|
||||||
|
PeekLookAtSocketFd4() (fd int, err error)
|
||||||
|
PeekLookAtSocketFd6() (fd int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Endpoint maintains the source/destination caching for a peer.
|
||||||
|
//
|
||||||
|
// dst: the remote address of a peer ("endpoint" in uapi terminology)
|
||||||
|
// src: the local address from which datagrams originate going to the peer
|
||||||
|
type Endpoint interface {
|
||||||
|
ClearSrc() // clears the source address
|
||||||
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
|
DstToString() string // returns the destination address (ip:port)
|
||||||
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
|
DstIP() netip.Addr
|
||||||
|
SrcIP() netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBindAlreadyOpen = errors.New("bind is already open")
|
||||||
|
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (fn ReceiveFunc) PrettyName() string {
|
||||||
|
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
||||||
|
// 0. cheese/taco.beansIPv6.func12.func21218-fm
|
||||||
|
name = strings.TrimSuffix(name, "-fm")
|
||||||
|
// 1. cheese/taco.beansIPv6.func12.func21218
|
||||||
|
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
|
||||||
|
name = name[idx+1:]
|
||||||
|
// 2. taco.beansIPv6.func12.func21218
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
var idx int
|
||||||
|
for idx = len(name) - 1; idx >= 0; idx-- {
|
||||||
|
if name[idx] < '0' || name[idx] > '9' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx == len(name)-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
const dotFunc = ".func"
|
||||||
|
if !strings.HasSuffix(name[:idx+1], dotFunc) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
name = name[:idx+1-len(dotFunc)]
|
||||||
|
// 3. taco.beansIPv6.func12
|
||||||
|
// 4. taco.beansIPv6
|
||||||
|
}
|
||||||
|
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
|
||||||
|
name = name[idx+1:]
|
||||||
|
// 5. beansIPv6
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return fmt.Sprintf("%p", fn)
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(name, "IPv4") {
|
||||||
|
return "v4"
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(name, "IPv6") {
|
||||||
|
return "v6"
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
222
wgstack/conn/controlfns.go
Normal file
222
wgstack/conn/controlfns.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/cilium/ebpf"
|
||||||
|
"github.com/cilium/ebpf/asm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
||||||
|
// the max supported by a default configuration of macOS. Some platforms will
|
||||||
|
// silently clamp the value to other maximums, such as linux clamping to
|
||||||
|
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
||||||
|
// around this limitation)
|
||||||
|
const socketBufferSize = 7 << 20
|
||||||
|
|
||||||
|
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||||
|
// It is used to apply platform specific configuration to the socket prior to
|
||||||
|
// bind.
|
||||||
|
type controlFn func(network, address string, c syscall.RawConn) error
|
||||||
|
|
||||||
|
// controlFns is a list of functions that are called from the listen config
|
||||||
|
// that can apply socket options.
|
||||||
|
var controlFns = []controlFn{}
|
||||||
|
|
||||||
|
const SO_ATTACH_REUSEPORT_EBPF = 52
|
||||||
|
|
||||||
|
//Create eBPF program that returns a hash to distribute packets
|
||||||
|
|
||||||
|
func createReuseportProgram() (*ebpf.Program, error) {
|
||||||
|
// This program uses the packet's hash and returns it modulo number of sockets
|
||||||
|
// Simple version: just return a counter-based distribution
|
||||||
|
//instructions := asm.Instructions{
|
||||||
|
// // Load the skb->hash value (already computed by kernel)
|
||||||
|
// asm.LoadMem(asm.R0, asm.R1, int16(unsafe.Offsetof(unix.XDPMd{}.RxQueueIndex)), asm.Word),
|
||||||
|
// asm.Return(),
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//// Alternative: simpler round-robin approach
|
||||||
|
//// This returns the CPU number, effectively round-robin
|
||||||
|
//instructions := asm.Instructions{
|
||||||
|
// asm.Mov.Reg(asm.R0, asm.R1), // Move ctx to R0
|
||||||
|
// asm.LoadMem(asm.R0, asm.R1, 0, asm.Word), // Load some field
|
||||||
|
// asm.Return(),
|
||||||
|
//}
|
||||||
|
|
||||||
|
// Better: Use BPF helper to get random/hash value
|
||||||
|
//instructions := asm.Instructions{
|
||||||
|
// // Call get_prandom_u32() to get random value for distribution
|
||||||
|
// asm.Mov.Imm(asm.R0, 0),
|
||||||
|
// asm.Call.Label("get_prandom_u32"),
|
||||||
|
// asm.Return(),
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
|
||||||
|
// Type: ebpf.SocketFilter,
|
||||||
|
// Instructions: instructions,
|
||||||
|
// License: "GPL",
|
||||||
|
//})
|
||||||
|
|
||||||
|
//instructions := asm.Instructions{
|
||||||
|
// // R1 contains pointer to skb
|
||||||
|
// // Load skb->hash at offset 0x20 (may vary by kernel, but 0x20 is common)
|
||||||
|
// asm.LoadMem(asm.R0, asm.R1, 0x20, asm.Word),
|
||||||
|
//
|
||||||
|
// // If hash is 0, use rxhash instead (fallback)
|
||||||
|
// asm.JEq.Imm(asm.R0, 0, "use_rxhash"),
|
||||||
|
// asm.Return().Sym("return"),
|
||||||
|
//
|
||||||
|
// // Fallback: load rxhash
|
||||||
|
// asm.LoadMem(asm.R0, asm.R1, 0x24, asm.Word).Sym("use_rxhash"),
|
||||||
|
// asm.Return(),
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
|
||||||
|
// Type: ebpf.SkReuseport,
|
||||||
|
// Instructions: instructions,
|
||||||
|
// License: "GPL",
|
||||||
|
//})
|
||||||
|
|
||||||
|
//instructions := asm.Instructions{
|
||||||
|
// // R1 = ctx (sk_reuseport_md)
|
||||||
|
// // R2 = sk_reuseport map (we'll use NULL/0 for default behavior)
|
||||||
|
// // R3 = key (select socket index)
|
||||||
|
// // R4 = flags
|
||||||
|
//
|
||||||
|
// // Simple approach: use the hash field from sk_reuseport_md
|
||||||
|
// // struct sk_reuseport_md { ... __u32 hash; ... } at offset 24
|
||||||
|
// asm.Mov.Reg(asm.R6, asm.R1), // Save ctx
|
||||||
|
//
|
||||||
|
// // Load the hash value at offset 24
|
||||||
|
// asm.LoadMem(asm.R2, asm.R6, 24, asm.Word),
|
||||||
|
//
|
||||||
|
// // Call bpf_sk_select_reuseport(ctx, map, key, flags)
|
||||||
|
// asm.Mov.Reg(asm.R1, asm.R6), // ctx
|
||||||
|
// asm.Mov.Imm(asm.R2, 0), // map (NULL = use default)
|
||||||
|
// asm.Mov.Reg(asm.R3, asm.R2), // key = hash we loaded (in R2)
|
||||||
|
// asm.Mov.Imm(asm.R4, 0), // flags
|
||||||
|
// asm.Call.Label("sk_select_reuseport"),
|
||||||
|
//
|
||||||
|
// // Return 0
|
||||||
|
// asm.Mov.Imm(asm.R0, 0),
|
||||||
|
// asm.Return(),
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
|
||||||
|
// Type: ebpf.SkReuseport,
|
||||||
|
// Instructions: instructions,
|
||||||
|
// License: "GPL",
|
||||||
|
//})
|
||||||
|
|
||||||
|
instructions := asm.Instructions{
|
||||||
|
// R1 = ctx (sk_reuseport_md pointer)
|
||||||
|
// Load hash from sk_reuseport_md at offset 24
|
||||||
|
//asm.LoadMem(asm.R0, asm.R1, 20, asm.Word),
|
||||||
|
|
||||||
|
// R1 = ctx (save it)
|
||||||
|
asm.Mov.Reg(asm.R6, asm.R1),
|
||||||
|
|
||||||
|
// Prepare string on stack: "BPF called!\n"
|
||||||
|
// We need to build the format string on the stack
|
||||||
|
asm.Mov.Reg(asm.R1, asm.R10), // R1 = frame pointer
|
||||||
|
asm.Add.Imm(asm.R1, -16), // R1 = stack location for string
|
||||||
|
|
||||||
|
// Write "BPF called!\n" to stack (we'll use a simpler version)
|
||||||
|
// Store immediate 64-bit values
|
||||||
|
asm.StoreImm(asm.R1, 0, 0x2066706220, asm.DWord), // "bpf "
|
||||||
|
asm.StoreImm(asm.R1, 8, 0x0a21, asm.DWord), // "!\n"
|
||||||
|
|
||||||
|
// Call bpf_trace_printk(fmt, fmt_size)
|
||||||
|
// R1 already points to format string
|
||||||
|
asm.Mov.Imm(asm.R2, 16), // R2 = format size
|
||||||
|
asm.Call.Label("bpf_printk"),
|
||||||
|
|
||||||
|
// Return 0 (send to socket 0 for testing)
|
||||||
|
asm.Mov.Imm(asm.R0, 0),
|
||||||
|
asm.Return(),
|
||||||
|
|
||||||
|
//asm.Mov.Imm(asm.R0, 0),
|
||||||
|
//// Just return the hash directly
|
||||||
|
//// The kernel will automatically modulo by number of sockets
|
||||||
|
//asm.Return(),
|
||||||
|
}
|
||||||
|
|
||||||
|
prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
|
||||||
|
Type: ebpf.SkReuseport,
|
||||||
|
Instructions: instructions,
|
||||||
|
License: "GPL",
|
||||||
|
})
|
||||||
|
|
||||||
|
return prog, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//func createReuseportProgram() (*ebpf.Program, error) {
|
||||||
|
// // Try offset 20 (common in newer kernels)
|
||||||
|
// instructions := asm.Instructions{
|
||||||
|
// asm.LoadMem(asm.R0, asm.R1, 20, asm.Word),
|
||||||
|
// asm.Return(),
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
|
||||||
|
// Type: ebpf.SkReuseport,
|
||||||
|
// Instructions: instructions,
|
||||||
|
// License: "GPL",
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// return prog, err
|
||||||
|
//}
|
||||||
|
|
||||||
|
func reusePortHax(fd uintptr) error {
|
||||||
|
prog, err := createReuseportProgram()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create eBPF program: %w", err)
|
||||||
|
}
|
||||||
|
//defer prog.Close()
|
||||||
|
sockErr := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, prog.FD())
|
||||||
|
if sockErr != nil {
|
||||||
|
return sockErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var EvilFdZero uintptr
|
||||||
|
|
||||||
|
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||||
|
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||||
|
// information OOB configuration for sticky sockets.
|
||||||
|
func listenConfig(q int) *net.ListenConfig {
|
||||||
|
return &net.ListenConfig{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
for _, fn := range controlFns {
|
||||||
|
if err := fn(network, address, c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if q == 0 {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
EvilFdZero = fd
|
||||||
|
})
|
||||||
|
// var e error
|
||||||
|
// err := c.Control(func(fd uintptr) {
|
||||||
|
// e = reusePortHax(fd)
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// if e != nil {
|
||||||
|
// return e
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
66
wgstack/conn/controlfns_linux.go
Normal file
66
wgstack/conn/controlfns_linux.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
|
||||||
|
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
||||||
|
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
||||||
|
// fail silently - the result of failure is lower performance on very fast
|
||||||
|
// links or high latency links.
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
// Set up to *mem_max
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||||
|
// Set beyond *mem_max if CAP_NET_ADMIN
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) //todo!!!
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) //todo!!!
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!!
|
||||||
|
//print(err.Error())
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||||
|
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
switch network {
|
||||||
|
case "udp4":
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "udp6":
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
9
wgstack/conn/default.go
Normal file
9
wgstack/conn/default.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func NewDefaultBind() Bind { return NewStdNetBind() }
|
||||||
12
wgstack/conn/errors_default.go
Normal file
12
wgstack/conn/errors_default.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
26
wgstack/conn/errors_linux.go
Normal file
26
wgstack/conn/errors_linux.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
var serr *os.SyscallError
|
||||||
|
if errors.As(err, &serr) {
|
||||||
|
// EIO is returned by udp_send_skb() if the device driver does not have
|
||||||
|
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
||||||
|
// See:
|
||||||
|
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
||||||
|
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
||||||
|
return serr.Err == unix.EIO
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
15
wgstack/conn/features_default.go
Normal file
15
wgstack/conn/features_default.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
return
|
||||||
|
}
|
||||||
33
wgstack/conn/features_linux.go
Normal file
33
wgstack/conn/features_linux.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
rc, err := conn.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a := 0
|
||||||
|
err = rc.Control(func(fd uintptr) {
|
||||||
|
a, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||||
|
|
||||||
|
txOffload = err == nil
|
||||||
|
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
|
||||||
|
rxOffload = errSyscall == nil && opt == 1
|
||||||
|
})
|
||||||
|
fmt.Printf("%d", a)
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return txOffload, rxOffload
|
||||||
|
}
|
||||||
21
wgstack/conn/gso_default.go
Normal file
21
wgstack/conn/gso_default.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
||||||
|
// offloading control data.
|
||||||
|
const gsoControlSize = 0
|
||||||
65
wgstack/conn/gso_linux.go
Normal file
65
wgstack/conn/gso_linux.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sizeOfGSOData = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||||
|
}
|
||||||
|
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
||||||
|
var gso uint16
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
||||||
|
return int(gso), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
||||||
|
// data in control untouched.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
existingLen := len(*control)
|
||||||
|
avail := cap(*control) - existingLen
|
||||||
|
space := unix.CmsgSpace(sizeOfGSOData)
|
||||||
|
if avail < space {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
gsoControl := (*control)[existingLen:]
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
||||||
|
hdr.Level = unix.SOL_UDP
|
||||||
|
hdr.Type = unix.UDP_SEGMENT
|
||||||
|
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
||||||
|
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
||||||
|
*control = (*control)[:existingLen+space]
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling UDP
|
||||||
|
// offloading control data.
|
||||||
|
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
||||||
64
wgstack/conn/mark_unix.go
Normal file
64
wgstack/conn/mark_unix.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
//go:build linux || openbsd || freebsd
|
||||||
|
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
var fwmarkIoctl int
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux", "android":
|
||||||
|
fwmarkIoctl = 36 /* unix.SO_MARK */
|
||||||
|
case "freebsd":
|
||||||
|
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
|
||||||
|
case "openbsd":
|
||||||
|
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||||
|
var operr error
|
||||||
|
if fwmarkIoctl == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.ipv4 != nil {
|
||||||
|
fd, err := s.ipv4.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = fd.Control(func(fd uintptr) {
|
||||||
|
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
err = operr
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.ipv6 != nil {
|
||||||
|
fd, err := s.ipv6.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = fd.Control(func(fd uintptr) {
|
||||||
|
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
err = operr
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
42
wgstack/conn/sticky_default.go
Normal file
42
wgstack/conn/sticky_default.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
||||||
|
// {get,set}srcControl feature set, but use alternatively named flags and need
|
||||||
|
// ports and require testing.
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
|
// offloading control data.
|
||||||
|
const stickyControlSize = 0
|
||||||
|
|
||||||
|
const StdNetSupportsStickySockets = false
|
||||||
105
wgstack/conn/sticky_linux.go
Normal file
105
wgstack/conn/sticky_linux.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
|
switch len(e.src) {
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return netip.AddrFrom4(info.Spec_dst)
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
// TODO: set zone. in order to do so we need to check if the address is
|
||||||
|
// link local, and if it is perform a syscall to turn the ifindex into a
|
||||||
|
// zone string because netip uses string zones.
|
||||||
|
return netip.AddrFrom16(info.Addr)
|
||||||
|
}
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||||
|
switch len(e.src) {
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return info.Ifindex
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return int32(info.Ifindex)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcToString() string {
|
||||||
|
return e.SrcIP().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
|
ep.ClearSrc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem []byte = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IP &&
|
||||||
|
hdr.Type == unix.IP_PKTINFO {
|
||||||
|
|
||||||
|
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
|
||||||
|
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||||
|
}
|
||||||
|
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
||||||
|
|
||||||
|
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||||
|
copy(ep.src, hdrBuf)
|
||||||
|
copy(ep.src[unix.CmsgLen(0):], data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||||
|
hdr.Type == unix.IPV6_PKTINFO {
|
||||||
|
|
||||||
|
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
|
||||||
|
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
||||||
|
|
||||||
|
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||||
|
copy(ep.src, hdrBuf)
|
||||||
|
copy(ep.src[unix.CmsgLen(0):], data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
||||||
|
// and source ifindex found in ep. control's len will be set to 0 in the event
|
||||||
|
// that ep is a default value.
|
||||||
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
|
if cap(*control) < len(ep.src) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*control = (*control)[:0]
|
||||||
|
*control = append(*control, ep.src...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
|
// offloading control data.
|
||||||
|
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||||
|
|
||||||
|
const StdNetSupportsStickySockets = true
|
||||||
42
wgstack/tun/checksum.go
Normal file
42
wgstack/tun/checksum.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package tun
|
||||||
|
|
||||||
|
import "encoding/binary"
|
||||||
|
|
||||||
|
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||||
|
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||||
|
ac := initial
|
||||||
|
i := 0
|
||||||
|
n := len(b)
|
||||||
|
for n >= 4 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
|
||||||
|
n -= 4
|
||||||
|
i += 4
|
||||||
|
}
|
||||||
|
for n >= 2 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
|
||||||
|
n -= 2
|
||||||
|
i += 2
|
||||||
|
}
|
||||||
|
if n == 1 {
|
||||||
|
ac += uint64(b[i]) << 8
|
||||||
|
}
|
||||||
|
return ac
|
||||||
|
}
|
||||||
|
|
||||||
|
func checksum(b []byte, initial uint64) uint16 {
|
||||||
|
ac := checksumNoFold(b, initial)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
return uint16(ac)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
||||||
|
sum := checksumNoFold(srcAddr, 0)
|
||||||
|
sum = checksumNoFold(dstAddr, sum)
|
||||||
|
sum = checksumNoFold([]byte{0, protocol}, sum)
|
||||||
|
tmp := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||||
|
return checksumNoFold(tmp, sum)
|
||||||
|
}
|
||||||
3
wgstack/tun/export.go
Normal file
3
wgstack/tun/export.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package tun
|
||||||
|
|
||||||
|
const VirtioNetHdrLen = virtioNetHdrLen
|
||||||
630
wgstack/tun/tcp_offload_linux.go
Normal file
630
wgstack/tun/tcp_offload_linux.go
Normal file
@@ -0,0 +1,630 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrTooManySegments = errors.New("tun: too many segments for TSO")
|
||||||
|
|
||||||
|
const tcpFlagsOffset = 13
|
||||||
|
|
||||||
|
const (
|
||||||
|
tcpFlagFIN uint8 = 0x01
|
||||||
|
tcpFlagPSH uint8 = 0x08
|
||||||
|
tcpFlagACK uint8 = 0x10
|
||||||
|
)
|
||||||
|
|
||||||
|
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
|
||||||
|
// kernel symbol is virtio_net_hdr.
|
||||||
|
type virtioNetHdr struct {
|
||||||
|
flags uint8
|
||||||
|
gsoType uint8
|
||||||
|
hdrLen uint16
|
||||||
|
gsoSize uint16
|
||||||
|
csumStart uint16
|
||||||
|
csumOffset uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtioNetHdr) decode(b []byte) error {
|
||||||
|
if len(b) < virtioNetHdrLen {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtioNetHdr) encode(b []byte) error {
|
||||||
|
if len(b) < virtioNetHdrLen {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
||||||
|
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
||||||
|
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
||||||
|
)
|
||||||
|
|
||||||
|
// flowKey represents the key for a flow.
|
||||||
|
type flowKey struct {
|
||||||
|
srcAddr, dstAddr [16]byte
|
||||||
|
srcPort, dstPort uint16
|
||||||
|
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
|
||||||
|
type tcpGROTable struct {
|
||||||
|
itemsByFlow map[flowKey][]tcpGROItem
|
||||||
|
itemsPool [][]tcpGROItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPGROTable() *tcpGROTable {
|
||||||
|
t := &tcpGROTable{
|
||||||
|
itemsByFlow: make(map[flowKey][]tcpGROItem, wgconn.IdealBatchSize),
|
||||||
|
itemsPool: make([][]tcpGROItem, wgconn.IdealBatchSize),
|
||||||
|
}
|
||||||
|
for i := range t.itemsPool {
|
||||||
|
t.itemsPool[i] = make([]tcpGROItem, 0, wgconn.IdealBatchSize)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
|
||||||
|
key := flowKey{}
|
||||||
|
addrSize := dstAddr - srcAddr
|
||||||
|
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
|
||||||
|
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
|
||||||
|
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
||||||
|
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
||||||
|
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||||
|
// returning the packets found for the flow, or inserting a new one if none
|
||||||
|
// is found.
|
||||||
|
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
||||||
|
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||||
|
items, ok := t.itemsByFlow[key]
|
||||||
|
if ok {
|
||||||
|
return items, ok
|
||||||
|
}
|
||||||
|
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||||
|
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert an item in the table for the provided packet and packet metadata.
|
||||||
|
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
||||||
|
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||||
|
item := tcpGROItem{
|
||||||
|
key: key,
|
||||||
|
bufsIndex: uint16(bufsIndex),
|
||||||
|
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
||||||
|
iphLen: uint8(tcphOffset),
|
||||||
|
tcphLen: uint8(tcphLen),
|
||||||
|
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
||||||
|
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
||||||
|
}
|
||||||
|
items, ok := t.itemsByFlow[key]
|
||||||
|
if !ok {
|
||||||
|
items = t.newItems()
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
t.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
||||||
|
items, _ := t.itemsByFlow[item.key]
|
||||||
|
items[i] = item
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
|
||||||
|
items, _ := t.itemsByFlow[key]
|
||||||
|
items = append(items[:i], items[i+1:]...)
|
||||||
|
t.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
||||||
|
// of a GRO evaluation across a vector of packets.
|
||||||
|
type tcpGROItem struct {
|
||||||
|
key flowKey
|
||||||
|
sentSeq uint32 // the sequence number
|
||||||
|
bufsIndex uint16 // the index into the original bufs slice
|
||||||
|
numMerged uint16 // the number of packets merged into this item
|
||||||
|
gsoSize uint16 // payload size
|
||||||
|
iphLen uint8 // ip header len
|
||||||
|
tcphLen uint8 // tcp header len
|
||||||
|
pshSet bool // psh flag is set
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) newItems() []tcpGROItem {
|
||||||
|
var items []tcpGROItem
|
||||||
|
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) reset() {
|
||||||
|
for k, items := range t.itemsByFlow {
|
||||||
|
items = items[:0]
|
||||||
|
t.itemsPool = append(t.itemsPool, items)
|
||||||
|
delete(t.itemsByFlow, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// canCoalesce represents the outcome of checking if two TCP packets are
|
||||||
|
// candidates for coalescing.
|
||||||
|
type canCoalesce int
|
||||||
|
|
||||||
|
const (
|
||||||
|
coalescePrepend canCoalesce = -1
|
||||||
|
coalesceUnavailable canCoalesce = 0
|
||||||
|
coalesceAppend canCoalesce = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||||
|
// described by item. This function makes considerations that match the kernel's
|
||||||
|
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
||||||
|
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||||
|
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||||
|
if tcphLen != item.tcphLen {
|
||||||
|
// cannot coalesce with unequal tcp options len
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if tcphLen > 20 {
|
||||||
|
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
||||||
|
// cannot coalesce with unequal tcp options
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pkt[0]>>4 == 6 {
|
||||||
|
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
|
||||||
|
// cannot coalesce with unequal Traffic class values
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if pkt[7] != pktTarget[7] {
|
||||||
|
// cannot coalesce with unequal Hop limit values
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if pkt[1] != pktTarget[1] {
|
||||||
|
// cannot coalesce with unequal ToS values
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if pkt[6]>>5 != pktTarget[6]>>5 {
|
||||||
|
// cannot coalesce with unequal DF or reserved bits. MF is checked
|
||||||
|
// further up the stack.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if pkt[8] != pktTarget[8] {
|
||||||
|
// cannot coalesce with unequal TTL values
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// seq adjacency
|
||||||
|
lhsLen := item.gsoSize
|
||||||
|
lhsLen += item.numMerged * item.gsoSize
|
||||||
|
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
||||||
|
if item.pshSet {
|
||||||
|
// We cannot append to a segment that has the PSH flag set, PSH
|
||||||
|
// can only be set on the final segment in a reassembled group.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
||||||
|
// A smaller than gsoSize packet has been appended previously.
|
||||||
|
// Nothing can come after a smaller packet on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalesceAppend
|
||||||
|
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
||||||
|
if pshSet {
|
||||||
|
// We cannot prepend with a segment that has the PSH flag set, PSH
|
||||||
|
// can only be set on the final segment in a reassembled group.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize < item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
||||||
|
// There's at least one previous merge, and we're larger than all
|
||||||
|
// previous. This would put multiple smaller packets on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalescePrepend
|
||||||
|
}
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
|
||||||
|
srcAddrAt := ipv4SrcAddrOffset
|
||||||
|
addrSize := 4
|
||||||
|
if isV6 {
|
||||||
|
srcAddrAt = ipv6SrcAddrOffset
|
||||||
|
addrSize = 16
|
||||||
|
}
|
||||||
|
tcpTotalLen := uint16(len(pkt) - int(iphLen))
|
||||||
|
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
|
||||||
|
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// coalesceResult represents the result of attempting to coalesce two TCP
|
||||||
|
// packets.
|
||||||
|
type coalesceResult int
|
||||||
|
|
||||||
|
const (
|
||||||
|
coalesceInsufficientCap coalesceResult = 0
|
||||||
|
coalescePSHEnding coalesceResult = 1
|
||||||
|
coalesceItemInvalidCSum coalesceResult = 2
|
||||||
|
coalescePktInvalidCSum coalesceResult = 3
|
||||||
|
coalesceSuccess coalesceResult = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
|
||||||
|
// item, returning the outcome. This function may swap bufs elements in the
|
||||||
|
// event of a prepend as item's bufs index is already being tracked for writing
|
||||||
|
// to a Device.
|
||||||
|
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
||||||
|
var pktHead []byte // the packet that will end up at the front
|
||||||
|
headersLen := item.iphLen + item.tcphLen
|
||||||
|
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
||||||
|
|
||||||
|
// Copy data
|
||||||
|
if mode == coalescePrepend {
|
||||||
|
pktHead = pkt
|
||||||
|
if cap(pkt)-bufsOffset < coalescedLen {
|
||||||
|
// We don't want to allocate a new underlying array if capacity is
|
||||||
|
// too small.
|
||||||
|
return coalesceInsufficientCap
|
||||||
|
}
|
||||||
|
if pshSet {
|
||||||
|
return coalescePSHEnding
|
||||||
|
}
|
||||||
|
if item.numMerged == 0 {
|
||||||
|
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
||||||
|
return coalesceItemInvalidCSum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
||||||
|
return coalescePktInvalidCSum
|
||||||
|
}
|
||||||
|
item.sentSeq = seq
|
||||||
|
extendBy := coalescedLen - len(pktHead)
|
||||||
|
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
|
||||||
|
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
|
||||||
|
// Flip the slice headers in bufs as part of prepend. The index of item
|
||||||
|
// is already being tracked for writing.
|
||||||
|
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
|
||||||
|
} else {
|
||||||
|
pktHead = bufs[item.bufsIndex][bufsOffset:]
|
||||||
|
if cap(pktHead)-bufsOffset < coalescedLen {
|
||||||
|
// We don't want to allocate a new underlying array if capacity is
|
||||||
|
// too small.
|
||||||
|
return coalesceInsufficientCap
|
||||||
|
}
|
||||||
|
if item.numMerged == 0 {
|
||||||
|
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
||||||
|
return coalesceItemInvalidCSum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
||||||
|
return coalescePktInvalidCSum
|
||||||
|
}
|
||||||
|
if pshSet {
|
||||||
|
// We are appending a segment with PSH set.
|
||||||
|
item.pshSet = pshSet
|
||||||
|
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
|
||||||
|
}
|
||||||
|
extendBy := len(pkt) - int(headersLen)
|
||||||
|
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
||||||
|
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
if gsoSize > item.gsoSize {
|
||||||
|
item.gsoSize = gsoSize
|
||||||
|
}
|
||||||
|
hdr := virtioNetHdr{
|
||||||
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
||||||
|
hdrLen: uint16(headersLen),
|
||||||
|
gsoSize: uint16(item.gsoSize),
|
||||||
|
csumStart: uint16(item.iphLen),
|
||||||
|
csumOffset: 16,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
|
||||||
|
// (IPv4) header checksum.
|
||||||
|
if isV6 {
|
||||||
|
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
||||||
|
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
|
||||||
|
} else {
|
||||||
|
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||||
|
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
|
||||||
|
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
|
||||||
|
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
|
||||||
|
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
|
||||||
|
}
|
||||||
|
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
|
||||||
|
|
||||||
|
// Calculate the pseudo header checksum and place it at the TCP checksum
|
||||||
|
// offset. Downstream checksum offloading will combine this with computation
|
||||||
|
// of the tcp header and payload checksum.
|
||||||
|
addrLen := 4
|
||||||
|
addrOffset := ipv4SrcAddrOffset
|
||||||
|
if isV6 {
|
||||||
|
addrLen = 16
|
||||||
|
addrOffset = ipv6SrcAddrOffset
|
||||||
|
}
|
||||||
|
srcAddrAt := bufsOffset + addrOffset
|
||||||
|
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
||||||
|
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
||||||
|
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
|
||||||
|
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
|
||||||
|
|
||||||
|
item.numMerged++
|
||||||
|
return coalesceSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipv4FlagMoreFragments uint8 = 0x20
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipv4SrcAddrOffset = 12
|
||||||
|
ipv6SrcAddrOffset = 8
|
||||||
|
maxUint16 = 1<<16 - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
|
||||||
|
// existing packets tracked in table. It will return false when pktI is not
|
||||||
|
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
|
||||||
|
// should be written to the Device.
|
||||||
|
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
|
||||||
|
pkt := bufs[pktI][offset:]
|
||||||
|
if len(pkt) > maxUint16 {
|
||||||
|
// A valid IPv4 or IPv6 packet will never exceed this.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
iphLen := int((pkt[0] & 0x0F) * 4)
|
||||||
|
if isV6 {
|
||||||
|
iphLen = 40
|
||||||
|
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
||||||
|
if ipv6HPayloadLen != len(pkt)-iphLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
||||||
|
if totalLen != len(pkt) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(pkt) < iphLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
|
||||||
|
if tcphLen < 20 || tcphLen > 60 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(pkt) < iphLen+tcphLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !isV6 {
|
||||||
|
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
||||||
|
// no GRO support for fragmented segments for now
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tcpFlags := pkt[iphLen+tcpFlagsOffset]
|
||||||
|
var pshSet bool
|
||||||
|
// not a candidate if any non-ACK flags (except PSH+ACK) are set
|
||||||
|
if tcpFlags != tcpFlagACK {
|
||||||
|
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
pshSet = true
|
||||||
|
}
|
||||||
|
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
|
||||||
|
// not a candidate if payload len is 0
|
||||||
|
if gsoSize < 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
|
||||||
|
srcAddrOffset := ipv4SrcAddrOffset
|
||||||
|
addrLen := 4
|
||||||
|
if isV6 {
|
||||||
|
srcAddrOffset = ipv6SrcAddrOffset
|
||||||
|
addrLen = 16
|
||||||
|
}
|
||||||
|
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||||
|
if !existing {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := len(items) - 1; i >= 0; i-- {
|
||||||
|
// In the best case of packets arriving in order iterating in reverse is
|
||||||
|
// more efficient if there are multiple items for a given flow. This
|
||||||
|
// also enables a natural table.deleteAt() in the
|
||||||
|
// coalesceItemInvalidCSum case without the need for index tracking.
|
||||||
|
// This algorithm makes a best effort to coalesce in the event of
|
||||||
|
// unordered packets, where pkt may land anywhere in items from a
|
||||||
|
// sequence number perspective, however once an item is inserted into
|
||||||
|
// the table it is never compared across other items later.
|
||||||
|
item := items[i]
|
||||||
|
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
|
||||||
|
if can != coalesceUnavailable {
|
||||||
|
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
|
||||||
|
switch result {
|
||||||
|
case coalesceSuccess:
|
||||||
|
table.updateAt(item, i)
|
||||||
|
return true
|
||||||
|
case coalesceItemInvalidCSum:
|
||||||
|
// delete the item with an invalid csum
|
||||||
|
table.deleteAt(item.key, i)
|
||||||
|
case coalescePktInvalidCSum:
|
||||||
|
// no point in inserting an item that we can't coalesce
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// failed to coalesce with any other packets; store the item in the flow
|
||||||
|
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTCP4NoIPOptions(b []byte) bool {
|
||||||
|
if len(b) < 40 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b[0]>>4 != 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b[0]&0x0F != 5 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b[9] != unix.IPPROTO_TCP {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTCP6NoEH(b []byte) bool {
|
||||||
|
if len(b) < 60 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b[0]>>4 != 6 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b[6] != unix.IPPROTO_TCP {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
||||||
|
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
|
||||||
|
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
||||||
|
// and recycle them across vectors of packets.
|
||||||
|
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
|
||||||
|
for i := range bufs {
|
||||||
|
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
||||||
|
return errors.New("invalid offset")
|
||||||
|
}
|
||||||
|
var coalesced bool
|
||||||
|
switch {
|
||||||
|
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
|
||||||
|
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
|
||||||
|
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
|
||||||
|
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
|
||||||
|
}
|
||||||
|
if !coalesced {
|
||||||
|
hdr := virtioNetHdr{}
|
||||||
|
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*toWrite = append(*toWrite, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpTSO splits packets from in into outBuffs, writing the size of each
|
||||||
|
// element into sizes. It returns the number of buffers populated, and/or an
|
||||||
|
// error.
|
||||||
|
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
|
||||||
|
iphLen := int(hdr.csumStart)
|
||||||
|
srcAddrOffset := ipv6SrcAddrOffset
|
||||||
|
addrLen := 16
|
||||||
|
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||||
|
in[10], in[11] = 0, 0 // clear ipv4 header checksum
|
||||||
|
srcAddrOffset = ipv4SrcAddrOffset
|
||||||
|
addrLen = 4
|
||||||
|
}
|
||||||
|
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||||
|
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
|
||||||
|
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
|
||||||
|
nextSegmentDataAt := int(hdr.hdrLen)
|
||||||
|
i := 0
|
||||||
|
for ; nextSegmentDataAt < len(in); i++ {
|
||||||
|
if i == len(outBuffs) {
|
||||||
|
return i - 1, ErrTooManySegments
|
||||||
|
}
|
||||||
|
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
|
||||||
|
if nextSegmentEnd > len(in) {
|
||||||
|
nextSegmentEnd = len(in)
|
||||||
|
}
|
||||||
|
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
||||||
|
totalLen := int(hdr.hdrLen) + segmentDataLen
|
||||||
|
sizes[i] = totalLen
|
||||||
|
out := outBuffs[i][outOffset:]
|
||||||
|
|
||||||
|
copy(out, in[:iphLen])
|
||||||
|
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||||
|
// For IPv4 we are responsible for incrementing the ID field,
|
||||||
|
// updating the total len field, and recalculating the header
|
||||||
|
// checksum.
|
||||||
|
if i > 0 {
|
||||||
|
id := binary.BigEndian.Uint16(out[4:])
|
||||||
|
id += uint16(i)
|
||||||
|
binary.BigEndian.PutUint16(out[4:], id)
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
||||||
|
ipv4CSum := ^checksum(out[:iphLen], 0)
|
||||||
|
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
||||||
|
} else {
|
||||||
|
// For IPv6 we are responsible for updating the payload length field.
|
||||||
|
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCP header
|
||||||
|
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
|
||||||
|
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
|
||||||
|
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
|
||||||
|
if nextSegmentEnd != len(in) {
|
||||||
|
// FIN and PSH should only be set on last segment
|
||||||
|
clearFlags := tcpFlagFIN | tcpFlagPSH
|
||||||
|
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
|
||||||
|
}
|
||||||
|
|
||||||
|
// payload
|
||||||
|
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
||||||
|
|
||||||
|
// TCP checksum
|
||||||
|
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
|
||||||
|
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
|
||||||
|
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
|
||||||
|
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
|
||||||
|
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
|
||||||
|
|
||||||
|
nextSegmentDataAt += int(hdr.gsoSize)
|
||||||
|
}
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
||||||
|
cSumAt := cSumStart + cSumOffset
|
||||||
|
// The initial value at the checksum offset should be summed with the
|
||||||
|
// checksum we compute. This is typically the pseudo-header checksum.
|
||||||
|
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
||||||
|
in[cSumAt], in[cSumAt+1] = 0, 0
|
||||||
|
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
52
wgstack/tun/tun.go
Normal file
52
wgstack/tun/tun.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Event int
|
||||||
|
|
||||||
|
const (
|
||||||
|
EventUp = 1 << iota
|
||||||
|
EventDown
|
||||||
|
EventMTUUpdate
|
||||||
|
)
|
||||||
|
|
||||||
|
type Device interface {
|
||||||
|
// File returns the file descriptor of the device.
|
||||||
|
File() *os.File
|
||||||
|
|
||||||
|
// Read one or more packets from the Device (without any additional headers).
|
||||||
|
// On a successful read it returns the number of packets read, and sets
|
||||||
|
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
|
||||||
|
// A nonzero offset can be used to instruct the Device on where to begin
|
||||||
|
// reading into each element of the bufs slice.
|
||||||
|
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
|
||||||
|
|
||||||
|
// Write one or more packets to the device (without any additional headers).
|
||||||
|
// On a successful write it returns the number of packets written. A nonzero
|
||||||
|
// offset can be used to instruct the Device on where to begin writing from
|
||||||
|
// each packet contained within the bufs slice.
|
||||||
|
Write(bufs [][]byte, offset int) (int, error)
|
||||||
|
|
||||||
|
// MTU returns the MTU of the Device.
|
||||||
|
MTU() (int, error)
|
||||||
|
|
||||||
|
// Name returns the current name of the Device.
|
||||||
|
Name() (string, error)
|
||||||
|
|
||||||
|
// Events returns a channel of type Event, which is fed Device events.
|
||||||
|
Events() <-chan Event
|
||||||
|
|
||||||
|
// Close stops the Device and closes the Event channel.
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// BatchSize returns the preferred/max number of packets that can be read or
|
||||||
|
// written in a single read/write call. BatchSize must not change over the
|
||||||
|
// lifetime of a Device.
|
||||||
|
BatchSize() int
|
||||||
|
}
|
||||||
664
wgstack/tun/tun_linux.go
Normal file
664
wgstack/tun/tun_linux.go
Normal file
@@ -0,0 +1,664 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
|
package tun
|
||||||
|
|
||||||
|
/* Implementation of the TUN device interface for linux
|
||||||
|
*/
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cloneDevicePath = "/dev/net/tun"
|
||||||
|
ifReqSize = unix.IFNAMSIZ + 64
|
||||||
|
)
|
||||||
|
|
||||||
|
type NativeTun struct {
|
||||||
|
tunFile *os.File
|
||||||
|
index int32 // if index
|
||||||
|
errors chan error // async error handling
|
||||||
|
events chan Event // device related events
|
||||||
|
netlinkSock int
|
||||||
|
netlinkCancel *rwcancel.RWCancel
|
||||||
|
hackListenerClosed sync.Mutex
|
||||||
|
statusListenersShutdown chan struct{}
|
||||||
|
batchSize int
|
||||||
|
vnetHdr bool
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
|
||||||
|
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
||||||
|
nameCache string // name of interface
|
||||||
|
nameErr error
|
||||||
|
|
||||||
|
readOpMu sync.Mutex // readOpMu guards readBuff
|
||||||
|
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
||||||
|
|
||||||
|
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
|
||||||
|
toWrite []int
|
||||||
|
tcp4GROTable, tcp6GROTable *tcpGROTable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) File() *os.File {
|
||||||
|
return tun.tunFile
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) routineHackListener() {
|
||||||
|
defer tun.hackListenerClosed.Unlock()
|
||||||
|
/* This is needed for the detection to work across network namespaces
|
||||||
|
* If you are reading this and know a better method, please get in touch.
|
||||||
|
*/
|
||||||
|
last := 0
|
||||||
|
const (
|
||||||
|
up = 1
|
||||||
|
down = 2
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
sysconn, err := tun.tunFile.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err2 := sysconn.Control(func(fd uintptr) {
|
||||||
|
_, err = unix.Write(int(fd), nil)
|
||||||
|
})
|
||||||
|
if err2 != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch err {
|
||||||
|
case unix.EINVAL:
|
||||||
|
if last != up {
|
||||||
|
// If the tunnel is up, it reports that write() is
|
||||||
|
// allowed but we provided invalid data.
|
||||||
|
tun.events <- EventUp
|
||||||
|
last = up
|
||||||
|
}
|
||||||
|
case unix.EIO:
|
||||||
|
if last != down {
|
||||||
|
// If the tunnel is down, it reports that no I/O
|
||||||
|
// is possible, without checking our provided data.
|
||||||
|
tun.events <- EventDown
|
||||||
|
last = down
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
// nothing
|
||||||
|
case <-tun.statusListenersShutdown:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNetlinkSocket() (int, error) {
|
||||||
|
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
saddr := &unix.SockaddrNetlink{
|
||||||
|
Family: unix.AF_NETLINK,
|
||||||
|
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
|
||||||
|
}
|
||||||
|
err = unix.Bind(sock, saddr)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return sock, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) routineNetlinkListener() {
|
||||||
|
defer func() {
|
||||||
|
unix.Close(tun.netlinkSock)
|
||||||
|
tun.hackListenerClosed.Lock()
|
||||||
|
close(tun.events)
|
||||||
|
tun.netlinkCancel.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
var err error
|
||||||
|
var msgn int
|
||||||
|
for {
|
||||||
|
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
||||||
|
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !tun.netlinkCancel.ReadyRead() {
|
||||||
|
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-tun.statusListenersShutdown:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
wasEverUp := false
|
||||||
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
|
|
||||||
|
if int(hdr.Len) > len(remain) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch hdr.Type {
|
||||||
|
case unix.NLMSG_DONE:
|
||||||
|
remain = []byte{}
|
||||||
|
|
||||||
|
case unix.RTM_NEWLINK:
|
||||||
|
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
|
||||||
|
if info.Index != tun.index {
|
||||||
|
// not our interface
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||||
|
tun.events <- EventUp
|
||||||
|
wasEverUp = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||||
|
// Don't emit EventDown before we've ever emitted EventUp.
|
||||||
|
// This avoids a startup race with HackListener, which
|
||||||
|
// might detect Up before we have finished reporting Down.
|
||||||
|
if wasEverUp {
|
||||||
|
tun.events <- EventDown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tun.events <- EventMTUUpdate
|
||||||
|
|
||||||
|
default:
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIFIndex(name string) (int32, error) {
|
||||||
|
fd, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer unix.Close(fd)
|
||||||
|
|
||||||
|
var ifr [ifReqSize]byte
|
||||||
|
copy(ifr[:], name)
|
||||||
|
_, _, errno := unix.Syscall(
|
||||||
|
unix.SYS_IOCTL,
|
||||||
|
uintptr(fd),
|
||||||
|
uintptr(unix.SIOCGIFINDEX),
|
||||||
|
uintptr(unsafe.Pointer(&ifr[0])),
|
||||||
|
)
|
||||||
|
|
||||||
|
if errno != 0 {
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
|
|
||||||
|
return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) setMTU(n int) error {
|
||||||
|
name, err := tun.Name()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open datagram socket
|
||||||
|
fd, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer unix.Close(fd)
|
||||||
|
|
||||||
|
var ifr [ifReqSize]byte
|
||||||
|
copy(ifr[:], name)
|
||||||
|
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
||||||
|
|
||||||
|
_, _, errno := unix.Syscall(
|
||||||
|
unix.SYS_IOCTL,
|
||||||
|
uintptr(fd),
|
||||||
|
uintptr(unix.SIOCSIFMTU),
|
||||||
|
uintptr(unsafe.Pointer(&ifr[0])),
|
||||||
|
)
|
||||||
|
|
||||||
|
if errno != 0 {
|
||||||
|
return errno
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) routineNetlinkRead() {
|
||||||
|
defer func() {
|
||||||
|
unix.Close(tun.netlinkSock)
|
||||||
|
tun.hackListenerClosed.Lock()
|
||||||
|
close(tun.events)
|
||||||
|
tun.netlinkCancel.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
var err error
|
||||||
|
var msgn int
|
||||||
|
for {
|
||||||
|
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
||||||
|
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !tun.netlinkCancel.ReadyRead() {
|
||||||
|
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wasEverUp := false
|
||||||
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
|
|
||||||
|
if int(hdr.Len) > len(remain) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch hdr.Type {
|
||||||
|
case unix.NLMSG_DONE:
|
||||||
|
remain = []byte{}
|
||||||
|
|
||||||
|
case unix.RTM_NEWLINK:
|
||||||
|
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
|
||||||
|
if info.Index != tun.index {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||||
|
tun.events <- EventUp
|
||||||
|
wasEverUp = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||||
|
if wasEverUp {
|
||||||
|
tun.events <- EventDown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tun.events <- EventMTUUpdate
|
||||||
|
|
||||||
|
default:
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) routineNetlink() {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
tun.netlinkSock, err = createNetlinkSocket()
|
||||||
|
if err != nil {
|
||||||
|
tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
|
||||||
|
if err != nil {
|
||||||
|
tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go tun.routineNetlinkListener()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) Close() error {
|
||||||
|
var err1, err2 error
|
||||||
|
tun.closeOnce.Do(func() {
|
||||||
|
if tun.statusListenersShutdown != nil {
|
||||||
|
close(tun.statusListenersShutdown)
|
||||||
|
if tun.netlinkCancel != nil {
|
||||||
|
err1 = tun.netlinkCancel.Cancel()
|
||||||
|
}
|
||||||
|
} else if tun.events != nil {
|
||||||
|
close(tun.events)
|
||||||
|
}
|
||||||
|
err2 = tun.tunFile.Close()
|
||||||
|
})
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) BatchSize() int {
|
||||||
|
return tun.batchSize
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TODO: support TSO with ECN bits
|
||||||
|
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tun *NativeTun) initFromFlags(name string) error {
|
||||||
|
sc, err := tun.tunFile.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if e := sc.Control(func(fd uintptr) {
|
||||||
|
var (
|
||||||
|
ifr *unix.Ifreq
|
||||||
|
)
|
||||||
|
ifr, err = unix.NewIfreq(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
got := ifr.Uint16()
|
||||||
|
if got&unix.IFF_VNET_HDR != 0 {
|
||||||
|
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tun.vnetHdr = true
|
||||||
|
tun.batchSize = wgconn.IdealBatchSize
|
||||||
|
} else {
|
||||||
|
tun.batchSize = 1
|
||||||
|
}
|
||||||
|
}); e != nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTUN creates a Device with the provided name and MTU.
|
||||||
|
func CreateTUN(name string, mtu int) (Device, error) {
|
||||||
|
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
|
||||||
|
}
|
||||||
|
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
|
||||||
|
tun, err := CreateTUNFromFile(fd, mtu)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if name != "tun" {
|
||||||
|
if err := tun.(*NativeTun).initFromFlags(name); err != nil {
|
||||||
|
tun.Close()
|
||||||
|
return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tun, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
|
||||||
|
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||||
|
tun := &NativeTun{
|
||||||
|
tunFile: file,
|
||||||
|
errors: make(chan error, 5),
|
||||||
|
events: make(chan Event, 5),
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := tun.Name()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to determine TUN name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tun.initFromFlags(name); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query TUN flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tun.batchSize == 0 {
|
||||||
|
tun.batchSize = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
tun.index, err = getIFIndex(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get TUN index: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tun.setMTU(mtu); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set MTU: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tun.statusListenersShutdown = make(chan struct{})
|
||||||
|
go tun.routineNetlink()
|
||||||
|
|
||||||
|
if tun.batchSize == 0 {
|
||||||
|
tun.batchSize = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
tun.tcp4GROTable = newTCPGROTable()
|
||||||
|
tun.tcp6GROTable = newTCPGROTable()
|
||||||
|
|
||||||
|
return tun, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) Name() (string, error) {
|
||||||
|
tun.nameOnce.Do(tun.initNameCache)
|
||||||
|
return tun.nameCache, tun.nameErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) initNameCache() {
|
||||||
|
sysconn, err := tun.tunFile.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
tun.nameErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = sysconn.Control(func(fd uintptr) {
|
||||||
|
var ifr [ifReqSize]byte
|
||||||
|
_, _, errno := unix.Syscall(
|
||||||
|
unix.SYS_IOCTL,
|
||||||
|
fd,
|
||||||
|
uintptr(unix.TUNGETIFF),
|
||||||
|
uintptr(unsafe.Pointer(&ifr[0])),
|
||||||
|
)
|
||||||
|
if errno != 0 {
|
||||||
|
tun.nameErr = errno
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tun.nameCache = unix.ByteSliceToString(ifr[:])
|
||||||
|
})
|
||||||
|
if err != nil && tun.nameErr == nil {
|
||||||
|
tun.nameErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) MTU() (int, error) {
|
||||||
|
name, err := tun.Name()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open datagram socket
|
||||||
|
fd, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer unix.Close(fd)
|
||||||
|
|
||||||
|
var ifr [ifReqSize]byte
|
||||||
|
copy(ifr[:], name)
|
||||||
|
|
||||||
|
_, _, errno := unix.Syscall(
|
||||||
|
unix.SYS_IOCTL,
|
||||||
|
uintptr(fd),
|
||||||
|
uintptr(unix.SIOCGIFMTU),
|
||||||
|
uintptr(unsafe.Pointer(&ifr[0])),
|
||||||
|
)
|
||||||
|
|
||||||
|
if errno != 0 {
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) Events() <-chan Event {
|
||||||
|
return tun.events
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
|
tun.writeOpMu.Lock()
|
||||||
|
defer func() {
|
||||||
|
tun.tcp4GROTable.reset()
|
||||||
|
tun.tcp6GROTable.reset()
|
||||||
|
tun.writeOpMu.Unlock()
|
||||||
|
}()
|
||||||
|
var (
|
||||||
|
errs error
|
||||||
|
total int
|
||||||
|
)
|
||||||
|
tun.toWrite = tun.toWrite[:0]
|
||||||
|
if tun.vnetHdr {
|
||||||
|
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
offset -= virtioNetHdrLen
|
||||||
|
} else {
|
||||||
|
for i := range bufs {
|
||||||
|
tun.toWrite = append(tun.toWrite, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, bufsI := range tun.toWrite {
|
||||||
|
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
|
||||||
|
if errors.Is(err, syscall.EBADFD) {
|
||||||
|
return total, os.ErrClosed
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
errs = errors.Join(errs, err)
|
||||||
|
} else {
|
||||||
|
total += n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total, errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
|
||||||
|
// each buffer. It mutates sizes to reflect the size of each element of bufs,
|
||||||
|
// and returns the number of packets read.
|
||||||
|
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||||
|
var hdr virtioNetHdr
|
||||||
|
if err := hdr.decode(in); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
in = in[virtioNetHdrLen:]
|
||||||
|
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
||||||
|
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||||
|
if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(in) > len(bufs[0][offset:]) {
|
||||||
|
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
|
||||||
|
}
|
||||||
|
n := copy(bufs[0][offset:], in)
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||||
|
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipVersion := in[0] >> 4
|
||||||
|
switch ipVersion {
|
||||||
|
case 4:
|
||||||
|
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||||
|
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||||
|
}
|
||||||
|
case 6:
|
||||||
|
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||||
|
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in) <= int(hdr.csumStart+12) {
|
||||||
|
return 0, errors.New("packet is too short")
|
||||||
|
}
|
||||||
|
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
||||||
|
if tcpHLen < 20 || tcpHLen > 60 {
|
||||||
|
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||||
|
}
|
||||||
|
hdr.hdrLen = hdr.csumStart + tcpHLen
|
||||||
|
if len(in) < int(hdr.hdrLen) {
|
||||||
|
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
||||||
|
}
|
||||||
|
if hdr.hdrLen < hdr.csumStart {
|
||||||
|
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
|
||||||
|
}
|
||||||
|
cSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||||
|
if cSumAt+1 >= len(in) {
|
||||||
|
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
||||||
|
}
|
||||||
|
|
||||||
|
return tcpTSO(in, hdr, bufs, sizes, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||||
|
tun.readOpMu.Lock()
|
||||||
|
defer tun.readOpMu.Unlock()
|
||||||
|
select {
|
||||||
|
case err := <-tun.errors:
|
||||||
|
return 0, err
|
||||||
|
default:
|
||||||
|
readInto := bufs[0][offset:]
|
||||||
|
if tun.vnetHdr {
|
||||||
|
readInto = tun.readBuff[:]
|
||||||
|
}
|
||||||
|
n, err := tun.tunFile.Read(readInto)
|
||||||
|
if errors.Is(err, syscall.EBADFD) {
|
||||||
|
err = os.ErrClosed
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if tun.vnetHdr {
|
||||||
|
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user