Merge remote-tracking branch 'origin/master' into fips140

This commit is contained in:
Wade Simmons
2026-06-01 09:52:57 -04:00
95 changed files with 5607 additions and 1384 deletions
+113
View File
@@ -0,0 +1,113 @@
name: Code-sign Windows binaries
description: >
Sign every .exe under a given path in place via the DefinedNet code-signer
Lambda. If `role` or `bucket` is empty, logs a notice and skips signing so
forks and dev branches without AWS access still produce usable builds.
inputs:
path:
description: "Directory whose .exe files should be signed in place"
required: true
role:
description: "IAM role ARN to assume via OIDC; empty disables signing"
required: false
default: ""
bucket:
description: "S3 staging bucket the code-signer Lambda reads from; empty disables signing"
required: false
default: ""
region:
description: "AWS region for the role and Lambda"
required: false
default: "us-east-2"
function-name:
description: "Code-signer Lambda function name"
required: false
default: "code-signer"
key-prefix:
description: "S3 key prefix the caller is authorized to write under"
required: false
default: "code-signing/slackhq/nebula"
runs:
using: composite
steps:
- name: Skip notice
if: inputs.role == '' || inputs.bucket == ''
shell: sh
run: echo "::notice::code-signer role or bucket not set; skipping code signing."
- name: Configure AWS credentials
if: inputs.role != '' && inputs.bucket != ''
uses: aws-actions/configure-aws-credentials@v6
with:
role-to-assume: ${{ inputs.role }}
aws-region: ${{ inputs.region }}
# Default is 12 retries to ride out IAM trust-policy propagation; once
# the role is stable we want a real misconfiguration to fail fast.
retry-max-attempts: 5
- name: Sign .exe files
if: inputs.role != '' && inputs.bucket != ''
shell: sh
env:
SIGN_PATH: ${{ inputs.path }}
BUCKET: ${{ inputs.bucket }}
FUNCTION_NAME: ${{ inputs.function-name }}
KEY_PREFIX: ${{ inputs.key-prefix }}
run: |
set -eu
RUN="${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
find "$SIGN_PATH" -name '*.exe' -print | while read -r path
do
rel=${path#"$SIGN_PATH"/}
file=$(basename "$path")
name=${file%.exe}
prefix="${KEY_PREFIX}/${RUN}"
src="${prefix}/unsigned/${rel}"
dst="${prefix}/signed/${rel}"
echo "::group::Sign ${rel}"
echo "Uploading unsigned to s3://${BUCKET}/${src}"
aws s3 cp --no-progress "$path" "s3://${BUCKET}/${src}" >/dev/null
echo "Invoking ${FUNCTION_NAME} Lambda"
payload=$(jq -nc \
--arg s "$src" \
--arg d "$dst" \
--arg p "$name" \
'{source_key: $s, dest_key: $d, program_name: $p}')
meta=$(aws lambda invoke \
--function-name "$FUNCTION_NAME" \
--cli-binary-format raw-in-base64-out \
--payload "$payload" \
--output json \
/tmp/sign-resp.json)
if echo "$meta" | jq -e '.FunctionError != null' >/dev/null
then
echo "::endgroup::"
echo "::error::code-signer Lambda failed for ${rel}"
cat /tmp/sign-resp.json >&2
exit 1
fi
echo "Downloading signed back to ${path}"
aws s3 cp --no-progress "s3://${BUCKET}/${dst}" "$path" >/dev/null
aws s3 rm "s3://${BUCKET}/${src}" >/dev/null 2>&1 || true
aws s3 rm "s3://${BUCKET}/${dst}" >/dev/null 2>&1 || true
# Sanity-check the bytes we got back actually carry an Authenticode
# signature that this machine can validate end to end.
status=$(powershell -NoProfile -Command "(Get-AuthenticodeSignature -FilePath '$path').Status" | tr -d '\r')
if [ "$status" != "Valid" ]
then
echo "::endgroup::"
echo "::error::${rel} signature status: ${status} (expected Valid)"
exit 1
fi
echo "Signed ${rel} (sha256=$(jq -r '.sha256' /tmp/sign-resp.json), status=${status})"
echo "::endgroup::"
done
-34
View File
@@ -1,34 +0,0 @@
name: gofmt
on:
push:
branches:
- master
pull_request:
paths:
- '.github/workflows/gofmt.yml'
- '**.go'
jobs:
gofmt:
name: Run gofmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
- name: gofmt
run: |
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
then
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
exit 1
fi
+18 -8
View File
@@ -24,7 +24,7 @@ jobs:
mv build/*.tar.gz release mv build/*.tar.gz release
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v6 uses: actions/upload-artifact@v7
with: with:
name: linux-latest name: linux-latest
path: release path: release
@@ -32,6 +32,9 @@ jobs:
build-windows: build-windows:
name: Build Windows name: Build Windows
runs-on: windows-latest runs-on: windows-latest
permissions:
id-token: write
contents: read
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
@@ -54,8 +57,15 @@ jobs:
mkdir build\dist\windows mkdir build\dist\windows
mv dist\windows\wintun build\dist\windows\ mv dist\windows\wintun build\dist\windows\
- name: Code-sign
uses: ./.github/actions/code-sign
with:
path: build
role: ${{ secrets.DEFINED_CODE_SIGNER_ROLE }}
bucket: ${{ secrets.DEFINED_CODE_SIGNER_BUCKET }}
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v6 uses: actions/upload-artifact@v7
with: with:
name: windows-latest name: windows-latest
path: build path: build
@@ -75,7 +85,7 @@ jobs:
- name: Import certificates - name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true' if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v6 uses: Apple-Actions/import-codesign-certs@v7
with: with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
@@ -104,7 +114,7 @@ jobs:
fi fi
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v6 uses: actions/upload-artifact@v7
with: with:
name: darwin-latest name: darwin-latest
path: ./release/* path: ./release/*
@@ -128,21 +138,21 @@ jobs:
- name: Download artifacts - name: Download artifacts
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/download-artifact@v7 uses: actions/download-artifact@v8
with: with:
name: linux-latest name: linux-latest
path: artifacts path: artifacts
- name: Login to Docker Hub - name: Login to Docker Hub
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: docker/login-action@v3 uses: docker/login-action@v4
with: with:
username: ${{ vars.DOCKERHUB_USERNAME }} username: ${{ vars.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v4
- name: Build and push images - name: Build and push images
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
@@ -163,7 +173,7 @@ jobs:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
- name: Download artifacts - name: Download artifacts
uses: actions/download-artifact@v7 uses: actions/download-artifact@v8
with: with:
path: artifacts path: artifacts
+81 -16
View File
@@ -14,10 +14,18 @@ on:
- 'go.sum' - 'go.sum'
jobs: jobs:
smoke-extra: smoke-extra-libvirt:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run extra smoke tests name: ${{ matrix.target }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
target:
- freebsd-amd64
- openbsd-amd64
- netbsd-amd64
- linux-amd64-ipv6disable
env: env:
VAGRANT_DEFAULT_PROVIDER: libvirt VAGRANT_DEFAULT_PROVIDER: libvirt
steps: steps:
@@ -40,28 +48,85 @@ jobs:
sudo chmod 666 /var/run/libvirt/libvirt-sock sudo chmod 666 /var/run/libvirt/libvirt-sock
vagrant plugin install vagrant-libvirt vagrant plugin install vagrant-libvirt
- name: freebsd-amd64 - name: ${{ matrix.target }}
run: make smoke-vagrant/freebsd-amd64 run: make smoke-vagrant/${{ matrix.target }}
- name: openbsd-amd64 timeout-minutes: 30
run: make smoke-vagrant/openbsd-amd64
- name: netbsd-amd64 # linux-386 needs VirtualBox, which conflicts with KVM/libvirt -- isolated job.
run: make smoke-vagrant/netbsd-amd64 smoke-extra-virtualbox:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: linux-386
runs-on: ubuntu-latest
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
steps:
- name: linux-amd64-ipv6disable - uses: actions/checkout@v6
run: make smoke-vagrant/linux-amd64-ipv6disable
# linux-386 runs last because it requires disabling KVM to use VirtualBox, - uses: actions/setup-go@v6
# which prevents libvirt (used by the other tests) from working after this point. with:
- name: install virtualbox for i386 test go-version: '1.25'
check-latest: true
- name: add hashicorp source
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
- name: install vagrant and virtualbox
run: | run: |
sudo apt-get install -y virtualbox sudo apt-get update && sudo apt-get install -y vagrant virtualbox
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
- name: linux-386 - name: linux-386
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
run: make smoke-vagrant/linux-386 run: make smoke-vagrant/linux-386
timeout-minutes: 30 timeout-minutes: 30
smoke-windows:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run windows smoke test
runs-on: windows-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
# WSL2 + Ubuntu so the smoke can run a real linux peer with its own
# netns. iputils-ping is needed for the in-WSL ping check. WSL1 has no
# real kernel and would lack /dev/net/tun, so we have to force WSL2.
- uses: Vampire/setup-wsl@v3
with:
distribution: Ubuntu-24.04
additional-packages: iputils-ping iproute2
# Vampire/setup-wsl provisions WSL1 even when the WSL2 platform is present.
# Convert the distro to WSL2 explicitly before we try to use /dev/net/tun.
- name: convert distro to WSL2
shell: pwsh
run: |
wsl --set-version Ubuntu-24.04 2
wsl --shutdown
wsl --list --verbose
- name: build windows nebula
run: make bin-windows
- name: build linux nebula for WSL
shell: bash
env:
GOOS: linux
GOARCH: amd64
run: |
mkdir -p build/linux-amd64
go build -o build/linux-amd64/nebula ./cmd/nebula
- name: run smoke-windows
shell: pwsh
working-directory: ./.github/workflows/smoke
run: ./smoke-windows.ps1
timeout-minutes: 15
+272
View File
@@ -0,0 +1,272 @@
#!/usr/bin/env pwsh
# Windows smoke test for the nebula tun + UDP + NLM code paths.
#
# Topology:
# - lighthouse runs natively on the Windows host (wintun + windows UDP)
# - peer runs inside WSL2 (Linux build of nebula, /dev/net/tun)
#
# WSL2 gives us a real netns boundary so the loopback fast-path on Windows
# does not short-circuit the overlay -- when WSL pings the lighthouse VPN IP,
# Linux has no idea that IP is local to the Windows host, so the packet is
# forced through nebula. Same in reverse.
$ErrorActionPreference = 'Stop'
# wsl.exe emits UTF-16 LE by default which PowerShell reads as bytes, mangling
# every captured string. WSL_UTF8 makes wsl.exe emit UTF-8 instead.
$env:WSL_UTF8 = '1'
$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.."
$Nebula = Join-Path $RepoRoot 'nebula.exe'
$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe'
$NebulaLinux = Join-Path $RepoRoot 'build\linux-amd64\nebula'
if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" }
if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" }
if (-not (Test-Path $NebulaLinux)) { throw "missing $NebulaLinux; build the linux nebula first" }
# Matches the distro installed by Vampire/setup-wsl in smoke-extra.yml.
$Distro = 'Ubuntu-24.04'
$listed = (wsl --list --quiet 2>$null) -join "`n"
if ($listed -notmatch [regex]::Escape($Distro)) {
throw "WSL distro $Distro not registered. Got: $listed"
}
Write-Host "Using WSL distro: $Distro"
# Windows host as seen from inside WSL: WSL's default-route gateway. We extract
# it with a regex rather than awk fields so PowerShell does not eat any '$N'
# tokens, and tabs/double-spaces in `ip route` output do not confuse a cut.
$ipCmd = 'ip route show default | grep -oE "([0-9]+\.){3}[0-9]+" | head -1'
$WindowsIp = (wsl -d $Distro -- bash -c $ipCmd).Trim()
if (-not $WindowsIp) { throw "could not determine Windows host IP from WSL" }
Write-Host "Windows host IP from WSL: $WindowsIp"
$WorkDir = Join-Path $env:TEMP 'nebula-smoke-windows'
if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir }
New-Item -ItemType Directory -Path $WorkDir | Out-Null
$WslDir = '/tmp/nebula-smoke'
wsl -d $Distro -- bash -c "rm -rf $WslDir && mkdir -p $WslDir" | Out-Null
$DevName = 'nebula-smoke'
$Ip1 = '192.168.241.1'
$Ip2 = '192.168.241.2'
$Port = 4242
& $NebulaCert ca -name 'smoke-ca' -out-crt "$WorkDir\ca.crt" -out-key "$WorkDir\ca.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" }
& $NebulaCert sign -name 'lighthouse' -networks "$Ip1/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\lighthouse.crt" -out-key "$WorkDir\lighthouse.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign lighthouse failed (exit $LASTEXITCODE)" }
& $NebulaCert sign -name 'peer' -networks "$Ip2/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\peer.crt" -out-key "$WorkDir\peer.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign peer failed (exit $LASTEXITCODE)" }
# Windows lighthouse config.
@"
pki:
ca: $WorkDir\ca.crt
cert: $WorkDir\lighthouse.crt
key: $WorkDir\lighthouse.key
static_host_map: {}
lighthouse:
am_lighthouse: true
interval: 60
hosts: []
listen:
host: 0.0.0.0
port: $Port
tun:
disabled: false
dev: $DevName
drop_local_broadcast: false
drop_multicast: false
tx_queue: 500
mtu: 1300
network_category: private
logging:
level: info
format: text
firewall:
outbound_action: drop
inbound_action: drop
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
"@ | Out-File -FilePath "$WorkDir\lighthouse.yml" -Encoding utf8
# WSL peer config (paths are POSIX, deliberately).
@"
pki:
ca: $WslDir/ca.crt
cert: $WslDir/peer.crt
key: $WslDir/peer.key
static_host_map:
"${Ip1}": ["${WindowsIp}:$Port"]
lighthouse:
am_lighthouse: false
interval: 60
hosts:
- "${Ip1}"
listen:
host: 0.0.0.0
port: 0
tun:
disabled: false
dev: nebula1
drop_local_broadcast: false
drop_multicast: false
tx_queue: 500
mtu: 1300
logging:
level: info
format: text
firewall:
outbound_action: drop
inbound_action: drop
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
"@ | Out-File -FilePath "$WorkDir\peer.yml" -Encoding utf8
# Stage WSL artifacts. Convert Windows paths to WSL paths ourselves rather than
# calling `wslpath`, because PowerShell's argument-passing to external EXEs
# strips backslashes from path arguments in ways that are hard to escape around.
# Translate an absolute, drive-letter Windows path (e.g. C:\dir\file) into the
# matching WSL mount path (/mnt/c/dir/file). The drive letter is lower-cased
# and every backslash in the remainder becomes a forward slash. Anything that
# does not start with "<letter>:\" is rejected with a terminating error.
function ConvertTo-WslPath {
    param([string]$WindowsPath)
    $parsed = [regex]::Match($WindowsPath, '^([A-Za-z]):\\(.*)$')
    if (-not $parsed.Success) {
        throw "cannot convert path to WSL: $WindowsPath"
    }
    $drive = $parsed.Groups[1].Value.ToLower()
    $tail = $parsed.Groups[2].Value.Replace('\','/')
    return "/mnt/$drive/$tail"
}
$WslWorkDir = ConvertTo-WslPath $WorkDir
$WslNebulaPath = ConvertTo-WslPath $NebulaLinux
wsl -d $Distro -- bash -c "cp '$WslWorkDir/ca.crt' '$WslWorkDir/peer.crt' '$WslWorkDir/peer.key' '$WslWorkDir/peer.yml' $WslDir/ && cp '$WslNebulaPath' $WslDir/nebula && chmod +x $WslDir/nebula"
# Make sure WSL has tun support and /dev/net/tun is usable before starting
# nebula. Diagnostics first so a fail here points at the real problem (e.g.
# WSL1 distros do not have a real kernel and will not have tun).
Write-Host '=== WSL diagnostic ==='
wsl --version 2>&1 | Out-Host
wsl --list --verbose 2>&1 | Out-Host
wsl -d $Distro -u root -- uname -a | Out-Host
wsl -d $Distro -u root -- bash -c "modprobe tun 2>&1 || true; mkdir -p /dev/net; [ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; chmod 600 /dev/net/tun; ls -l /dev/net/tun"
if ($LASTEXITCODE -ne 0) { throw "failed to prepare /dev/net/tun in WSL (TUN support missing?)" }
# Deliberately no New-NetFirewallRule calls here -- nebula's windows_bypass_wdf
# feature is supposed to install WFP permit filters that let inbound traffic
# through Windows Defender Firewall on its own. If this smoke regresses, that
# feature regressed.
$lhOut = Join-Path $WorkDir 'lighthouse.out.log'
$lhErr = Join-Path $WorkDir 'lighthouse.err.log'
$lhProc = Start-Process -FilePath $Nebula -ArgumentList @('-config', "$WorkDir\lighthouse.yml") `
-PassThru -NoNewWindow `
-RedirectStandardOutput $lhOut `
-RedirectStandardError $lhErr
# Run nebula in WSL as root with no sudo + no shell wrapper. PowerShell's
# Start-Process arg quoting mangles `bash -c "..."` strings that contain
# spaces/redirections, so we skip bash entirely and let Start-Process do the
# stdout/stderr capture itself.
$peerOut = Join-Path $WorkDir 'peer.out.log'
$peerErr = Join-Path $WorkDir 'peer.err.log'
$peerProc = Start-Process -FilePath 'wsl' `
-ArgumentList @('-d', $Distro, '-u', 'root', '--', "$WslDir/nebula", '-config', "$WslDir/peer.yml") `
-PassThru -NoNewWindow `
-RedirectStandardOutput $peerOut `
-RedirectStandardError $peerErr
# Poll $Predicate until it yields a truthy value or $TimeoutSec seconds have
# elapsed. The deadline is checked before every evaluation (including the
# first), and the loop sleeps 500 ms between attempts. On timeout it throws
# "timed out waiting for: <What>". Anything the predicate itself throws
# propagates to the caller unchanged.
function Wait-Until {
    param([scriptblock]$Predicate, [int]$TimeoutSec, [string]$What)
    $giveUpAt = (Get-Date).AddSeconds($TimeoutSec)
    while ($true) {
        if ((Get-Date) -ge $giveUpAt) {
            throw "timed out waiting for: $What"
        }
        if (& $Predicate) {
            return
        }
        Start-Sleep -Milliseconds 500
    }
}
# Main assertion sequence. Every Wait-Until call below polls for up to 30
# seconds and throws on timeout. Any throw lands in the catch block, which
# prints diagnostics and rethrows; the finally block then stops both nebula
# processes regardless of outcome.
try {
# Check 1: Windows reports a connection profile for the $DevName adapter whose
# NetworkCategory is Private. Abort early if the lighthouse already exited.
Wait-Until -TimeoutSec 30 -What "windows wintun adapter $DevName with NetworkCategory=Private" -Predicate {
if ($lhProc.HasExited) { throw "lighthouse exited (code $($lhProc.ExitCode)) before tun was ready" }
$p = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue
$p -and ("$($p.NetworkCategory)" -ieq 'Private')
}
Write-Host "OK: $DevName NetworkCategory=Private"
# Check 2: the nebula1 interface inside WSL carries the peer address $Ip2.
# The bash one-liner prints "yes" only when grep finds the address.
Wait-Until -TimeoutSec 30 -What "WSL nebula1 with $Ip2" -Predicate {
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before tun was ready" }
$r = wsl -d $Distro -u root -- bash -c "ip -o addr show nebula1 2>/dev/null | grep -q 'inet $Ip2' && echo yes"
("$r").Trim() -eq 'yes'
}
Write-Host "OK: WSL nebula1 has $Ip2"
# Check 3: a single ping from inside WSL to the lighthouse overlay IP succeeds.
Wait-Until -TimeoutSec 30 -What "ping from WSL peer to windows lighthouse ($Ip1)" -Predicate {
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before ping succeeded" }
$r = wsl -d $Distro -u root -- bash -c "ping -c1 -W1 $Ip1 >/dev/null 2>&1 && echo OK"
("$r").Trim() -eq 'OK'
}
Write-Host "OK: WSL peer -> windows lighthouse"
# Check 4: the reverse direction, using the Windows ping.exe exit code.
Wait-Until -TimeoutSec 30 -What "ping from windows lighthouse to WSL peer ($Ip2)" -Predicate {
$null = & ping.exe -n 1 -w 1000 $Ip2
$LASTEXITCODE -eq 0
}
Write-Host "OK: windows lighthouse -> WSL peer"
Write-Host ''
Write-Host 'All smoke checks passed.'
}
catch {
# Failure path: print the captured stdout/stderr of both nebula processes.
# Missing log files are ignored via -ErrorAction SilentlyContinue.
Write-Host ''
Write-Host '=== lighthouse stdout ==='
Get-Content $lhOut -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== lighthouse stderr ==='
Get-Content $lhErr -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== peer stdout ==='
Get-Content $peerOut -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== peer stderr ==='
Get-Content $peerErr -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== nebula WFP filters ==='
# Dump nebula-installed filters so we can verify they got registered with
# the conditions we expect. netsh writes the dump to $wfpDump; only entries
# matching 'Nebula' (plus 80 lines of trailing context each) are printed.
$wfpDump = Join-Path $WorkDir 'wfp.xml'
netsh wfp show filters file=$wfpDump 2>&1 | Out-Null
if (Test-Path $wfpDump) {
Select-String -Path $wfpDump -Pattern 'Nebula' -Context 0,80 -ErrorAction SilentlyContinue | Out-Host
}
# Rethrow the original error so the script still fails after the diagnostics.
throw
}
finally {
# Cleanup runs on success and failure alike. Stop the Windows lighthouse first
# and give it up to 5 seconds to exit.
if (-not $lhProc.HasExited) {
Stop-Process -Id $lhProc.Id -Force -ErrorAction SilentlyContinue
$lhProc.WaitForExit(5000) | Out-Null
}
# Kill the peer from inside WSL. pkill returns 1 when no match and wsl
# propagates that; the smoke is done so we don't want it to leak into the
# script's exit code, hence the explicit LASTEXITCODE reset afterwards.
wsl -d $Distro -u root -- bash -c "pkill -f $WslDir/nebula 2>/dev/null; true" | Out-Null
$global:LASTEXITCODE = 0
# Finally stop the Windows-side wsl process started for the peer, if it is
# still running.
if ($peerProc -and -not $peerProc.HasExited) {
Stop-Process -Id $peerProc.Id -Force -ErrorAction SilentlyContinue
}
}
@@ -1,7 +1,7 @@
# -*- mode: ruby -*- # -*- mode: ruby -*-
# vi: set ft=ruby : # vi: set ft=ruby :
Vagrant.configure("2") do |config| Vagrant.configure("2") do |config|
config.vm.box = "generic/netbsd9" config.vm.box = "DefinedNet/netbsd10"
config.vm.synced_folder "../build", "/nebula", type: "rsync" config.vm.synced_folder "../build", "/nebula", type: "rsync"
end end
+100 -98
View File
@@ -13,8 +13,8 @@ on:
- 'go.sum' - 'go.sum'
jobs: jobs:
test-linux: static:
name: Build all and test on ubuntu-linux name: Static checks
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@@ -25,8 +25,16 @@ jobs:
go-version: '1.25' go-version: '1.25'
check-latest: true check-latest: true
- name: Build - name: Install goimports
run: make all run: go install golang.org/x/tools/cmd/goimports@latest
- name: gofmt
run: |
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
then
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
exit 1
fi
- name: Vet - name: Vet
run: make vet run: make vet
@@ -36,87 +44,43 @@ jobs:
with: with:
version: v2.5 version: v2.5
- name: Test
run: make test
- name: End 2 end
run: make e2evv
- name: Build test mobile
run: make build-test-mobile
- uses: actions/upload-artifact@v6
with:
name: e2e packet flow linux-latest
path: e2e/mermaid/linux-latest
if-no-files-found: warn
test-linux-boringcrypto:
name: Build and test on linux with boringcrypto
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make bin-boringcrypto
- name: Test
run: make test-boringcrypto
- name: End 2 end
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
test-linux-fips140:
name: Build and test on linux with fips140=on
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make fips140
- name: Test
run: make fips140 test
- name: End 2 end
run: make fips140 e2evv
test-linux-pkcs11:
name: Build and test on linux with pkcs11
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make bin-pkcs11
- name: Test
run: make test-pkcs11
test: test:
name: Build and test on ${{ matrix.os }} name: Test ${{ matrix.name }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
fail-fast: false
matrix: matrix:
os: [windows-latest, macos-latest] include:
- name: linux
os: ubuntu-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
- name: linux-boringcrypto
os: ubuntu-latest
build-cmd: make bin-boringcrypto
test-cmd: make test-boringcrypto
e2e-cmd: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
- name: linux-fips140
os: ubuntu-latest
build-cmd: make fips140
test-cmd: make fips140 test
e2e-cmd: make fips140 e2evv
- name: linux-pkcs11
os: ubuntu-latest
build-cmd: make bin-pkcs11
test-cmd: make test-pkcs11
e2e-cmd: ''
- name: macos
os: macos-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
- name: windows
os: windows-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
@@ -126,28 +90,66 @@ jobs:
go-version: '1.25' go-version: '1.25'
check-latest: true check-latest: true
- name: Build nebula - name: Build
run: go build ./cmd/nebula run: ${{ matrix.build-cmd }}
- name: Build nebula-cert - name: Cross-build darwin-amd64
run: go build ./cmd/nebula-cert if: matrix.name == 'macos'
run: GOARCH=amd64 go build -o /tmp/nebula-amd64 ./cmd/nebula && GOARCH=amd64 go build -o /tmp/nebula-cert-amd64 ./cmd/nebula-cert
- name: Vet
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.5
- name: Test - name: Test
run: make test run: ${{ matrix.test-cmd }}
- name: End 2 end - name: End 2 end
run: make e2evv if: matrix.e2e-cmd != ''
run: ${{ matrix.e2e-cmd }}
- uses: actions/upload-artifact@v6 - uses: actions/upload-artifact@v7
if: matrix.e2e-cmd != '' && always()
with: with:
name: e2e packet flow ${{ matrix.os }} name: e2e packet flow ${{ matrix.name }}
path: e2e/mermaid/${{ matrix.os }} path: e2e/mermaid/
if-no-files-found: warn if-no-files-found: warn
cross-build:
name: Cross-build ${{ matrix.name }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- {name: linux-arm, make-target: all-cross-linux-arm}
- {name: linux-mips, make-target: all-cross-linux-mips}
- {name: linux-other, make-target: all-cross-linux-other}
- {name: freebsd, make-target: all-freebsd}
- {name: openbsd, make-target: all-openbsd}
- {name: netbsd, make-target: all-netbsd}
- {name: windows, make-target: all-cross-windows}
- {name: mobile, make-target: build-test-mobile}
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build ${{ matrix.name }}
run: make -j"$(nproc)" ${{ matrix.make-target }}
finish:
name: CI status
if: always()
needs: [static, test, cross-build]
runs-on: ubuntu-latest
steps:
- name: Fail if any upstream job failed
if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled')
run: |
echo "upstream results: ${{ toJSON(needs) }}"
exit 1
- name: All upstream jobs passed
run: echo "ok"
+42 -1
View File
@@ -60,6 +60,18 @@ ALL = $(ALL_LINUX) \
windows-amd64 \ windows-amd64 \
windows-arm64 windows-arm64
# Cross-build shards used by .github/workflows/test.yml — same as ALL_*
# but with the arch that has a native CI runner removed, so the cross-build
# job is not duplicating coverage the native test jobs already give.
ALL_CROSS_LINUX = $(filter-out linux-amd64,$(ALL_LINUX))
# ALL_CROSS_LINUX further split into family sub-shards so each can run on
# its own CI runner in parallel. Union of the three must equal
# ALL_CROSS_LINUX; adding a new linux arch goes into the matching family.
ALL_CROSS_LINUX_ARM = linux-arm-5 linux-arm-6 linux-arm-7 linux-arm64
ALL_CROSS_LINUX_MIPS = linux-mips linux-mipsle linux-mips64 linux-mips64le linux-mips-softfloat
ALL_CROSS_LINUX_OTHER = linux-386 linux-ppc64le linux-riscv64 linux-loong64
e2e: e2e:
$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e $(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
@@ -82,6 +94,35 @@ DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert) all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
all-linux: $(ALL_LINUX:%=build/%/nebula) $(ALL_LINUX:%=build/%/nebula-cert)
all-freebsd: $(ALL_FREEBSD:%=build/%/nebula) $(ALL_FREEBSD:%=build/%/nebula-cert)
all-openbsd: $(ALL_OPENBSD:%=build/%/nebula) $(ALL_OPENBSD:%=build/%/nebula-cert)
all-netbsd: $(ALL_NETBSD:%=build/%/nebula) $(ALL_NETBSD:%=build/%/nebula-cert)
all-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert build/darwin-arm64/nebula build/darwin-arm64/nebula-cert
all-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
# CI cross-build shards. darwin-arm64 is covered by the native macos-latest
# job; windows-amd64 is covered by the native windows-latest job; both are
# omitted here to avoid building them a second time. darwin-amd64 stays in
# all-cross-darwin because intel mac is only a labeled/master-time native
# job, so PRs still need cross-build coverage for it.
all-cross-linux: $(ALL_CROSS_LINUX:%=build/%/nebula) $(ALL_CROSS_LINUX:%=build/%/nebula-cert)
all-cross-linux-arm: $(ALL_CROSS_LINUX_ARM:%=build/%/nebula) $(ALL_CROSS_LINUX_ARM:%=build/%/nebula-cert)
all-cross-linux-mips: $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula) $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula-cert)
all-cross-linux-other: $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula) $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula-cert)
all-cross-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
all-cross-windows: build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
docker: docker/linux-$(shell go env GOARCH) docker: docker/linux-$(shell go env GOARCH)
release: $(ALL:%=build/nebula-%.tar.gz) release: $(ALL:%=build/nebula-%.tar.gz)
@@ -267,5 +308,5 @@ smoke-vagrant/%: bin-docker build/%/nebula
cd .github/workflows/smoke/ && ./smoke-vagrant.sh $* cd .github/workflows/smoke/ && ./smoke-vagrant.sh $*
.FORCE: .FORCE:
.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv fips140 proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/% .PHONY: all all-linux all-freebsd all-openbsd all-netbsd all-darwin all-windows all-cross-linux all-cross-linux-arm all-cross-linux-mips all-cross-linux-other all-cross-darwin all-cross-windows bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv fips140 proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
.DEFAULT_GOAL := bin .DEFAULT_GOAL := bin
+174 -33
View File
@@ -2,24 +2,42 @@ package nebula
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"math"
mathbits "math/bits"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
) )
const bitsPerWord = 64
// Bits is a sliding-window anti-replay tracker. The window is stored as a
// circular bitmap packed into uint64 words (8x denser than a []bool), so a
// length-N window costs N/8 bytes. length must be a power of two.
type Bits struct { type Bits struct {
length uint64 length uint64
lengthMask uint64
current uint64 current uint64
bits []bool bits []uint64
lostCounter metrics.Counter lostCounter metrics.Counter
dupeCounter metrics.Counter dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter outOfWindowCounter metrics.Counter
} }
func NewBits(bits uint64) *Bits { func NewBits(length uint64) *Bits {
if length == 0 || length&(length-1) != 0 {
panic(fmt.Sprintf("Bits length must be a power of two, got %d", length))
}
nWords := length / bitsPerWord
if nWords == 0 {
nWords = 1
}
b := &Bits{ b := &Bits{
length: bits, length: length,
bits: make([]bool, bits, bits), lengthMask: length - 1,
bits: make([]uint64, nWords),
current: 0, current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
@@ -27,71 +45,194 @@ func NewBits(bits uint64) *Bits {
} }
// There is no counter value 0, mark it to avoid counting a lost packet later. // There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true b.bits[0] = 1
b.current = 0
return b return b
} }
// get reports whether the bit tracking counter i is currently set.
//
// i is first folded into the circular window with lengthMask. The resulting
// position is a bit index, not a word index: its upper bits (slot>>6) select
// the uint64 word that holds the bit, and its low six bits (slot&63) select
// the bit inside that word.
func (b *Bits) get(i uint64) bool {
	slot := i & b.lengthMask
	word := b.bits[slot>>6]
	return (word>>(slot&63))&1 == 1
}
// set marks the bit tracking counter i as received. The counter is mapped to
// its circular bit position with lengthMask, then split into a word index
// (slot>>6) and a bit offset inside that word (slot&63).
func (b *Bits) set(i uint64) {
	slot := i & b.lengthMask
	word, bit := slot>>6, uint64(1)<<(slot&63)
	b.bits[word] |= bit
}
// clearRange zeroes `count` consecutive bits of the circular bitmap, beginning
// at position `startPos` (already masked to [0, length)), and returns how many
// of those bits were set before they were cleared. count must be in
// [1, length].
//
// The span is processed in up to three pieces: a leading partial word, a run
// of whole words, and a trailing partial word. Positions wrap around the end
// of the bitmap via lengthMask.
func (b *Bits) clearRange(startPos, count uint64) uint64 {
	// A span covering the entire window: tally and wipe every word.
	if count >= b.length {
		var evicted uint64
		for idx := range b.bits {
			evicted += uint64(mathbits.OnesCount64(b.bits[idx]))
		}
		clear(b.bits)
		return evicted
	}

	var evicted uint64
	cursor, left := startPos, count

	// Leading piece: runs from cursor to whichever comes first of the end of
	// its word, the end of the requested span, or the end of the bitmap (the
	// last bound matters for windows smaller than one word).
	offset := cursor & 63
	n := min(64-offset, left, b.length-cursor)
	lead := uint64(math.MaxUint64)
	if n < 64 {
		// Shifting by 64 would yield 0, so only build the mask this way for
		// a genuinely partial word.
		lead = (uint64(1)<<n - 1) << offset
	}
	idx := cursor >> 6
	evicted += uint64(mathbits.OnesCount64(b.bits[idx] & lead))
	b.bits[idx] &^= lead
	left -= n
	cursor = (cursor + n) & b.lengthMask

	// Middle: whole words, counting their set bits before zeroing them.
	for ; left >= 64; left -= 64 {
		idx = cursor >> 6
		evicted += uint64(mathbits.OnesCount64(b.bits[idx]))
		b.bits[idx] = 0
		cursor = (cursor + 64) & b.lengthMask
	}

	// Tail: the low `left` bits of the next word. The mask is not shifted, so
	// this relies on cursor sitting at bit 0 of its word at this point.
	if left > 0 {
		tail := uint64(1)<<left - 1
		idx = cursor >> 6
		evicted += uint64(mathbits.OnesCount64(b.bits[idx] & tail))
		b.bits[idx] &^= tail
	}
	return evicted
}
// strictlyWithinWindow reports whether counter i falls inside the range the
// bitmap currently tracks. The callers visible in this file only invoke it
// with i <= b.current.
//
//   - Warmup (b.current < b.length): the window has not slid yet, so the
//     tracked range is every counter below b.length.
//   - Steady state: the lower bound is b.current-b.length (exclusive); any i
//     above that bound is reported as within the window.
//
// The warmup case returns on its own so that b.current-b.length is only
// evaluated when it cannot wrap around. Falling through to the subtraction
// during warmup would wrap to a huge value and report counters near the top
// of the uint64 range as in-window.
func (b *Bits) strictlyWithinWindow(i uint64) bool {
	if b.current < b.length {
		return i < b.length
	}
	return i > b.current-b.length
}
// Check returns true if i is within (or way out in front of) the window, and not a replay
func (b *Bits) Check(l *slog.Logger, i uint64) bool { func (b *Bits) Check(l *slog.Logger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current { if i > b.current {
return true return true
} }
// If i is within the window, check if it's been set already. if b.strictlyWithinWindow(i) {
if i > b.current-b.length || i < b.length && b.current < b.length { return !b.get(i)
return !b.bits[i%b.length]
} }
// Not within the window // Not within the window
if l.Enabled(context.Background(), slog.LevelDebug) { if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("rejected a packet (top)", l.Debug("rejected a packet (top)", "current", b.current, "incoming", i)
"current", b.current,
"incoming", i,
)
} }
return false return false
} }
// Update has three branches:
// - i == b.current+1: fast path; advance the cursor by one and lose-count
// the slot we just stomped (only past warmup; see the i > b.length guard
// below).
// - i > b.current+1: jump path; clear all slots between current and i
// (or up to a full window's worth, whichever is smaller) via clearRange,
// then mark i. Two arms here: a warmup arm that handles the very first
// window before the cursor has slid, and a steady-state arm that treats
// every cleared empty slot as a lost packet.
// - i <= b.current: in-window check for duplicates; out-of-window otherwise.
//
// NewBits seeds bits[0]=1 so counter 0 looks "received". A warmup jump that
// stays below b.length never clears that marker (clearRange starts at
// position current+1 >= 1 and stops before wrapping). A jump that reaches
// b.length or beyond does clear position 0, but by then b.current >= b.length
// and the bit no longer stands for counter 0. The marker prevents a
// fictitious "lost" hit on the first real counter.
func (b *Bits) Update(l *slog.Logger, i uint64) bool { func (b *Bits) Update(l *slog.Logger, i uint64) bool {
// If i is the next number, return true and update current. // Fast path: i is the next expected counter. Split out so the function
// stays small and avoids paying for the slow paths' slog argument-build
// stack frame on every call. The bit read/test/write is inlined to
// touch the backing word once.
if i == b.current+1 { if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter pos := i & b.lengthMask
// The very first window can only be tracked as lost once we are on the 2nd window or greater word := pos >> 6
if b.bits[i%b.length] == false && i > b.length { mask := uint64(1) << (pos & 63)
w := b.bits[word]
if i > b.length && w&mask == 0 {
b.lostCounter.Inc(1) b.lostCounter.Inc(1)
} }
b.bits[i%b.length] = true b.bits[word] = w | mask
b.current = i b.current = i
return true return true
} }
return b.updateSlow(l, i)
}
// updateSlow handles jumps, in-window backfill, dupes, and out-of-window.
func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool {
// If i is a jump, adjust the window, record lost, update current, and return true // If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current { if i > b.current {
lost := int64(0) end := i
// Zero out the bits between the current and the new counter value, limited by the window size, if end > b.current+b.length {
// since the window is shifting end = b.current + b.length
for n := b.current + 1; n <= min(i, b.current+b.length); n++ { }
if b.bits[n%b.length] == false && n > b.length { count := end - b.current
startPos := (b.current + 1) & b.lengthMask
var lost int64
if b.current >= b.length {
// Steady state: every cleared slot is past warmup, so any unset
// bit we evict is a lost packet from the previous cycle.
wasSet := b.clearRange(startPos, count)
lost = int64(count) - int64(wasSet)
} else {
// Warmup (the very first window). Some cleared slots represent
// packets <= length where eviction is not "lost" in the usual
// sense. This branch is taken at most once per connection so we
// don't bother optimizing it.
for n := b.current + 1; n <= end; n++ {
if !b.get(n) && n > b.length {
lost++ lost++
} }
b.bits[n%b.length] = false }
b.clearRange(startPos, count)
} }
// Only record any skipped packets as a result of the window moving further than the window length // Anything past the new window can never be backfilled, so it's lost.
// Any loss within the new window will be accounted for in future calls if i > b.current+b.length {
lost += max(0, int64(i-b.current-b.length)) lost += int64(i - b.current - b.length)
}
b.lostCounter.Inc(lost) b.lostCounter.Inc(lost)
b.bits[i%b.length] = true b.set(i)
b.current = i b.current = i
return true return true
} }
// If i is within the current window but below the current counter, // If i is within the current window but below the current counter, check to see if it's a duplicate
// Check to see if it's a duplicate if b.strictlyWithinWindow(i) {
if i > b.current-b.length || i < b.length && b.current < b.length { pos := i & b.lengthMask
if b.current == i || b.bits[i%b.length] == true { word := pos >> 6
mask := uint64(1) << (pos & 63)
w := b.bits[word]
if b.current == i || w&mask != 0 {
if l.Enabled(context.Background(), slog.LevelDebug) { if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window", l.Debug("Receive window",
"accepted", false, "accepted", false,
@@ -104,7 +245,7 @@ func (b *Bits) Update(l *slog.Logger, i uint64) bool {
return false return false
} }
b.bits[i%b.length] = true b.bits[word] = w | mask
return true return true
} }
+276 -129
View File
@@ -7,61 +7,79 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
// snapshot expands the packed []uint64 bitmap into one bool per position
// (b.length entries in total) so tests can compare it against a readable
// []bool literal.
func (b *Bits) snapshot() []bool {
	flags := make([]bool, b.length)
	for pos := range flags {
		flags[pos] = b.get(uint64(pos))
	}
	return flags
}
// TestBitsRequiresPowerOfTwo checks that NewBits panics for a length that is
// zero or not a power of two, and constructs normally for lengths that are.
func TestBitsRequiresPowerOfTwo(t *testing.T) {
	for _, bad := range []uint64{10, 0} {
		assert.Panics(t, func() { NewBits(bad) })
	}
	for _, good := range []uint64{1, 16, 1024, 16384} {
		assert.NotPanics(t, func() { NewBits(good) })
	}
}
func TestBits(t *testing.T) { func TestBits(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(16)
assert.EqualValues(t, 16, b.length)
// make sure it is the right size
assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work. // This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1)) assert.True(t, b.Check(l, 1))
assert.True(t, b.Update(l, 1)) assert.True(t, b.Update(l, 1))
assert.EqualValues(t, 1, b.current) assert.EqualValues(t, 1, b.current)
g := []bool{true, true, false, false, false, false, false, false, false, false} g := []bool{true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.snapshot())
// Receive two // Receive two
assert.True(t, b.Check(l, 2)) assert.True(t, b.Check(l, 2))
assert.True(t, b.Update(l, 2)) assert.True(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
g = []bool{true, true, true, false, false, false, false, false, false, false} g = []bool{true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.snapshot())
// Receive two again - it will fail // Receive two again - it will fail
assert.False(t, b.Check(l, 2)) assert.False(t, b.Check(l, 2))
assert.False(t, b.Update(l, 2)) assert.False(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element // Jump ahead to 25, which clears the window and sets slot 25%16 = 9.
assert.True(t, b.Check(l, 15)) assert.True(t, b.Check(l, 25))
assert.True(t, b.Update(l, 15)) assert.True(t, b.Update(l, 25))
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false} g = []bool{false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.snapshot())
// Mark 14, which is allowed because it is in the window // Mark 24, which is in window (current 25, length 16, window covers [10,25]).
assert.True(t, b.Check(l, 14)) assert.True(t, b.Check(l, 24))
assert.True(t, b.Update(l, 14)) assert.True(t, b.Update(l, 24))
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.snapshot())
// Mark 5, which is not allowed because it is not in the window // Mark 5, not allowed because 5 <= current-length (25-16=9).
assert.False(t, b.Check(l, 5)) assert.False(t, b.Check(l, 5))
assert.False(t, b.Update(l, 5)) assert.False(t, b.Update(l, 5))
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.snapshot())
// make sure we handle wrapping around once to the current position // Make sure we handle wrapping around once to the same slot. With
b = NewBits(10) // length=16, packets 1 and 17 share slot 1.
b = NewBits(16)
assert.True(t, b.Update(l, 1)) assert.True(t, b.Update(l, 1))
assert.True(t, b.Update(l, 11)) assert.True(t, b.Update(l, 17))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits) assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}, b.snapshot())
// Walk through a few windows in order // Walk through a few windows in order
b = NewBits(10) b = NewBits(16)
for i := uint64(1); i <= 100; i++ { for i := uint64(1); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i)
@@ -72,24 +90,31 @@ func TestBits(t *testing.T) {
func TestBitsLargeJumps(t *testing.T) { func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
b := NewBits(10)
// length=16. Update(55) from current=0:
// warmup, per-bit loop sees no n>16 with unset bits (slot 0 was set by
// NewBits and gets re-evaluated when n=16; n=16 is not strictly > 16),
// so the loop contributes 0. The jump exceeds the window so we record
// 55 - 0 - 16 = 39 packets fell out the back.
b := NewBits(16)
b.lostCounter.Clear() b.lostCounter.Clear()
assert.True(t, b.Update(l, 55))
assert.Equal(t, int64(39), b.lostCounter.Count())
b = NewBits(10) // Update(100): clears 16 slots starting at slot 56%16=8. Only slot 7 (for
b.lostCounter.Clear() // packet 55) was set, so 16 - 1 = 15 evicted slots had unset bits.
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54 // Plus 100 - 55 - 16 = 29 packets fell past the window. Total 44.
assert.Equal(t, int64(45), b.lostCounter.Count()) assert.True(t, b.Update(l, 100))
assert.Equal(t, int64(39+44), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99 // Update(200): same shape: 16 - 1 = 15 evicted unset, plus 200 - 100 - 16 = 84 past window. Total 99.
assert.Equal(t, int64(89), b.lostCounter.Count()) assert.True(t, b.Update(l, 200))
assert.Equal(t, int64(39+44+99), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
} }
func TestBitsDupeCounter(t *testing.T) { func TestBitsDupeCounter(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(16)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
@@ -114,120 +139,117 @@ func TestBitsDupeCounter(t *testing.T) {
func TestBitsOutOfWindowCounter(t *testing.T) { func TestBitsOutOfWindowCounter(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(16)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
// Jump to 20 (warmup branch + 4 past-window packets).
assert.True(t, b.Update(l, 20)) assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.True(t, b.Update(l, 21)) // 9 single-step advances, each evicts a slot whose bit was cleared during
assert.True(t, b.Update(l, 22)) // the jump above and whose value was never seen, so each contributes 1
assert.True(t, b.Update(l, 23)) // to lostCounter.
assert.True(t, b.Update(l, 24)) for n := uint64(21); n <= 29; n++ {
assert.True(t, b.Update(l, 25)) assert.True(t, b.Update(l, n))
assert.True(t, b.Update(l, 26)) }
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
// 0 is below current-length (29-16=13) so it falls outside the window.
assert.False(t, b.Update(l, 0)) assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost // 4 from the Update(20) jump + 9 from 21..29.
assert.Equal(t, int64(13), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
} }
func TestBitsLostCounter(t *testing.T) { func TestBitsLostCounter(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(16)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 20)) // Walk 20..29 like the original, just with a bigger window. Same
assert.True(t, b.Update(l, 21)) // reasoning as TestBitsOutOfWindowCounter: 4 past-window from Update(20),
assert.True(t, b.Update(l, 22)) // then 9 more from the unit advances.
assert.True(t, b.Update(l, 23)) for n := uint64(20); n <= 29; n++ {
assert.True(t, b.Update(l, 24)) assert.True(t, b.Update(l, n))
assert.True(t, b.Update(l, 25)) }
assert.True(t, b.Update(l, 26)) assert.Equal(t, int64(13), b.lostCounter.Count())
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
b = NewBits(10) b = NewBits(16)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 9)) // Update(15) clears the warmup window (no lost), sets slot 15.
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(1), b.lostCounter.Count())
// Now let's fill in the window, should end up with 8 lost packets
assert.True(t, b.Update(l, 12))
assert.True(t, b.Update(l, 13))
assert.True(t, b.Update(l, 14))
assert.True(t, b.Update(l, 15)) assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Update(16): slot 0 was already set (NewBits seeded it), and 16 is not
// strictly > length, so nothing is recorded as lost.
assert.True(t, b.Update(l, 16)) assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Update(17): we jumped straight from 0 to 15, so slot 1 was cleared
// (and never re-set). 17 > 16 is past warmup, so packet 1 is recorded lost.
assert.True(t, b.Update(l, 17)) assert.True(t, b.Update(l, 17))
assert.True(t, b.Update(l, 18)) assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Jump ahead by a window size // Fill in 18..30 in single steps. Each i evicts slot i%16. Slots 2..14
assert.True(t, b.Update(l, 29)) // were all cleared during Update(15), and we never re-set any of them,
assert.Equal(t, int64(8), b.lostCounter.Count()) // so each i in 18..30 is a fresh lost packet — 13 more.
// Now lets walk ahead normally through the window, the missed packets should fill in for n := uint64(18); n <= 30; n++ {
assert.True(t, b.Update(l, 30)) assert.True(t, b.Update(l, n))
assert.True(t, b.Update(l, 31)) }
assert.True(t, b.Update(l, 32)) assert.Equal(t, int64(14), b.lostCounter.Count())
assert.True(t, b.Update(l, 33))
assert.True(t, b.Update(l, 34))
assert.True(t, b.Update(l, 35))
assert.True(t, b.Update(l, 36))
assert.True(t, b.Update(l, 37))
assert.True(t, b.Update(l, 38))
// 39 packets tracked, 22 seen, 17 lost
assert.Equal(t, int64(17), b.lostCounter.Count())
// Jump ahead by 2 windows, should have recording 1 full window missing // Jump ahead by exactly one window size.
assert.True(t, b.Update(l, 58)) assert.True(t, b.Update(l, 46))
assert.Equal(t, int64(27), b.lostCounter.Count()) // end = min(46, 30+16) = 46, count = 16, all slots cleared. Before the
// Now lets walk ahead normally through the window, the missed packets should fill in from this window // jump every slot 0..15 had been set (Update(15), (16), (17), 18..30),
assert.True(t, b.Update(l, 59)) // so wasSet=16 and 46 == current+length means no past-window slack:
assert.True(t, b.Update(l, 60)) // lost contribution = 0.
assert.True(t, b.Update(l, 61)) assert.Equal(t, int64(14), b.lostCounter.Count())
assert.True(t, b.Update(l, 62))
assert.True(t, b.Update(l, 63)) // Walk 47..55. The Update(46) jump cleared every slot, so only slot 14
assert.True(t, b.Update(l, 64)) // (for packet 46) is set when we start. Each subsequent unit step lands
assert.True(t, b.Update(l, 65)) // on a slot that was cleared and is past warmup, so it counts as lost.
assert.True(t, b.Update(l, 66)) // 9 more = 23.
assert.True(t, b.Update(l, 67)) for n := uint64(47); n <= 55; n++ {
// 68 packets tracked, 32 seen, 36 missed assert.True(t, b.Update(l, n))
assert.Equal(t, int64(36), b.lostCounter.Count()) }
assert.Equal(t, int64(23), b.lostCounter.Count())
// Jump ahead by two windows: clears the window plus past-window loss.
assert.True(t, b.Update(l, 87))
// current=55, length=16. end = min(87, 71) = 71. count=16, all slots
// cleared. Slots set before the clear are slots 14,15,0..7 (10 total).
// Lost from clear = 16 - 10 = 6. Past window: 87 - 55 - 16 = 16. +22.
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
} }
func TestBitsLostCounterIssue1(t *testing.T) { func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
b := NewBits(10) b := NewBits(16)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
// Receive 4, backfill 1, then 9, 2, 3, 5, 6, 7 (skip 8), 10, 11, 14.
// Then jump to 25 — slot 25%16=9 is being evicted, but it had been set
// (we received packet 9), so no spurious lost increment. The original
// regression was about double-counting a missing packet when its slot
// got cleared on a jump. With the jump path now using clearRange's
// word-level wasSet count, the same semantics hold.
assert.True(t, b.Update(l, 4)) assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1)) assert.True(t, b.Update(l, 1))
@@ -244,7 +266,7 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7)) assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8)) // Skip packet 8.
assert.True(t, b.Update(l, 10)) assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11)) assert.True(t, b.Update(l, 11))
@@ -252,9 +274,23 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.True(t, b.Update(l, 14)) assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19)) // Jump to 25. With length=16, slot 25%16=9 corresponds to packet 9
// (which we DID receive), so its bit is set and no lost++ from that
// eviction. The trace below shows the only loss is packet 8.
assert.True(t, b.Update(l, 25))
// current was 14, i=25. end=min(25,30)=25. count=11. startPos=15.
// steady? current=14<16, so warmup branch: per-bit n=15..25, count those
// with !get(n) AND n>16. n=17..25 are >16. Among slots 17%16=1..25%16=9
// did we set slots 1..9 (packets 1..9)? Yes for all but slot 8 (packet 8
// was skipped). n=24 maps to slot 8 which is FALSE → lost++. All other
// n in 17..25 map to slots that are set. n=16 is not strictly > 16. So
// lost = 1.
assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(1), b.lostCounter.Count())
// Fill in 12, 13, 15, 16. Each is below current=25 (in-window). 16 must
// recheck slot 0 — it was set by NewBits and then cleared by the
// Update(25) jump, so 16 backfills cleanly.
assert.True(t, b.Update(l, 12)) assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13)) assert.True(t, b.Update(l, 13))
@@ -263,29 +299,140 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16)) assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above // We missed packet 8 above and that loss is still recorded once, never
// double-counted, never zeroed.
assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
} }
func BenchmarkBits(b *testing.B) { // TestBitsWarmupOvershoot exercises the jump path's warmup arm with an
z := NewBits(10) // overshoot past one full window. NewBits leaves current=0 with only slot 0
for n := 0; n < b.N; n++ { // "set" by the marker. Jumping straight to length+k must (a) clear every
for i := range z.bits { // slot the jump straddles, (b) count only past-window slack (not the
z.bits[i] = true // in-window slots, which never had a "lost" tenant during warmup), and
} // (c) leave the cursor at the new counter so subsequent unit advances
for i := range z.bits { // count from steady state. The marker bit at slot 0 is irrelevant once
z.bits[i] = false // current >= length.
func TestBitsWarmupOvershoot(t *testing.T) {
l := test.NewLogger()
b := NewBits(16)
b.lostCounter.Clear()
// Jump from current=0 to i=20 (length=16, overshoot=4).
// Warmup arm: counts slots in [1..16] where bit unset and n>length.
// Only n=16 was unset and >length: but slot 16%16=0 is the marker,
// so b.get(16) reads bits[0]=1 and skips. Result: 0 lost from the loop.
// Past-window: i - current - length = 20 - 0 - 16 = 4 lost.
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(4), b.lostCounter.Count())
assert.Equal(t, uint64(20), b.current)
// Steady state now (current=20 >= length=16). Unit advance to 21
// stomps slot 21%16=5, which was cleared by the jump and not reset,
// so this is +1 lost.
assert.True(t, b.Update(l, 21))
assert.Equal(t, int64(5), b.lostCounter.Count())
} }
// TestBitsCheckAcrossWarmupBoundary pins Check's in-window decision on both
// sides of the warmup boundary. While current < length, every counter below
// length is treated as in-window; once current >= length, the window's
// exclusive lower bound is current-length. Check never mutates the tracker,
// so only the single Update call below changes state.
func TestBitsCheckAcrossWarmupBoundary(t *testing.T) {
	logger := test.NewLogger()
	window := NewBits(16)

	// Warmup with current=0: counter 0 maps onto the bit NewBits seeded, so
	// it reads as already received.
	assert.False(t, window.Check(logger, 0), "marker slot should look already-received")

	// Warmup: every 0 < i < length is in-window and unset, so it is accepted.
	for i := uint64(1); i < 16; i++ {
		assert.True(t, window.Check(logger, i), "warmup in-window i=%d should be accepted", i)
	}

	// Warmup: anything above current is accepted as a forward move,
	// regardless of how far ahead it is.
	for _, ahead := range []uint64{16, 1_000_000} {
		assert.True(t, window.Check(logger, ahead))
	}

	// Cross into steady state: current=100, length=16, so the in-window
	// range is [85..100].
	assert.True(t, window.Update(logger, 100))

	steady := []struct {
		counter uint64
		want    bool
	}{
		{84, false},  // exactly current-length: just outside the window
		{85, true},   // first in-window counter, not yet seen
		{100, false}, // current itself: in-window but already set
		{50, false},  // far behind the window
	}
	for _, tc := range steady {
		assert.Equal(t, tc.want, window.Check(logger, tc.counter))
	}
}
// TestBitsMarkerInvariant follows the bit NewBits seeds at position 0 through
// a warmup made only of unit advances. While current < length the marker must
// make counter 0 look already received, and the unit advance that lands on
// position 0 (counter == length) must not be charged as a lost packet.
func TestBitsMarkerInvariant(t *testing.T) {
	logger := test.NewLogger()
	window := NewBits(8)

	// Counter 0 is the seeded marker, so Check reports it as already seen.
	assert.False(t, window.Check(logger, 0))

	// Update(0) while current=0 is rejected, and the test expects the
	// duplicate counter to tick exactly once.
	window.dupeCounter.Clear()
	assert.False(t, window.Update(logger, 0))
	assert.Equal(t, int64(1), window.dupeCounter.Count())

	// Unit advances 1..7 stay inside the first window and only set their own
	// bits, so the marker is untouched.
	for n := uint64(1); n < 8; n++ {
		assert.True(t, window.Update(logger, n))
	}
	assert.False(t, window.Check(logger, 0))

	// The advance to 8 lands on position 0, where the marker lives. The
	// lost-counter guard (i > length) is false for 8 == 8, so no loss is
	// recorded for it.
	window.lostCounter.Clear()
	assert.True(t, window.Update(logger, 8))
	assert.Equal(t, int64(0), window.lostCounter.Count())

	// Position 0 now stands for counter 8, which has been received.
	assert.False(t, window.Check(logger, 8))
}
// BenchmarkBitsUpdateInOrder measures Update when every counter is exactly
// current+1, i.e. strictly in-order delivery through a 16384-bit window.
func BenchmarkBitsUpdateInOrder(b *testing.B) {
	logger := test.NewLogger()
	window := NewBits(16384)
	for counter := uint64(1); counter <= uint64(b.N); counter++ {
		window.Update(logger, counter)
	}
}
// BenchmarkBitsUpdateReorder measures Update under light reordering: counters
// arrive in swapped pairs (2, 1, 4, 3, ...), so the second Update of each pair
// is one behind current and has to be backfilled inside the window.
func BenchmarkBitsUpdateReorder(b *testing.B) {
	logger := test.NewLogger()
	window := NewBits(16384)
	for pair := uint64(0); pair < uint64(b.N); pair++ {
		late := 2*pair + 1
		window.Update(logger, late+1)
		window.Update(logger, late)
	}
}
// BenchmarkBitsUpdateLargeJumps stresses the clearRange word-level path.
func BenchmarkBitsUpdateLargeJumps(b *testing.B) {
l := test.NewLogger()
z := NewBits(16384)
for n := 0; n < b.N; n++ {
z.Update(l, uint64(n+1)*1000)
} }
} }
+4
View File
@@ -217,6 +217,10 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp
return nil, err return nil, err
} }
if signer.Certificate.Curve() != c.Curve() {
return nil, ErrCurveMismatch
}
if signer.Certificate.Expired(now) { if signer.Certificate.Expired(now) {
return nil, ErrRootExpired return nil, ErrRootExpired
} }
+28
View File
@@ -654,3 +654,31 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err) require.NoError(t, err)
} }
// TestCertificateV2_CurveMismatch checks that CAPool.verify rejects a
// certificate whose curve differs from the curve of its signing CA, and that
// the rejection is reported as ErrCurveMismatch rather than some later failure.
func TestCertificateV2_CurveMismatch(t *testing.T) {
	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})

	caPem, err := ca.MarshalPEM()
	require.NoError(t, err)

	caPool := NewCAPool()
	b, err := caPool.AddCAFromPEM(caPem)
	require.NoError(t, err)
	assert.Empty(t, b)

	// Baseline: a P256 leaf inside the CA's networks, signed by the P256 CA,
	// verifies cleanly.
	cIp1 := mustParsePrefixUnmapped("10.0.0.1/24")
	c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1}, nil, []string{"test"})
	fp, err := c.Fingerprint()
	require.NoError(t, err)
	_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
	require.NoError(t, err)

	// Flip the leaf's curve so it no longer matches the CA. Assert the specific
	// sentinel: a bare "any error" check would also be satisfied by a signature
	// failure on the mutated certificate and would not prove the curve check ran.
	c2 := c.(*certificateV2)
	c2.curve = Curve_CURVE25519
	fp, err = c.Fingerprint()
	require.NoError(t, err)
	_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
	require.ErrorIs(t, err, ErrCurveMismatch)
}
+3
View File
@@ -112,6 +112,9 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
} }
switch c.details.curve { switch c.details.curve {
case Curve_CURVE25519: case Curve_CURVE25519:
if len(key) != ed25519.PublicKeySize {
return false //avoids a panic internal to ed25519
}
return ed25519.Verify(key, b, c.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+3
View File
@@ -151,6 +151,9 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
switch c.curve { switch c.curve {
case Curve_CURVE25519: case Curve_CURVE25519:
if len(key) != ed25519.PublicKeySize {
return false //avoids a panic internal to ed25519
}
return ed25519.Verify(key, b, c.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+1
View File
@@ -22,6 +22,7 @@ var (
ErrCaNotFound = errors.New("could not find ca for the certificate") ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized") ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present") ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present")
ErrCurveMismatch = errors.New("certificate curve does not match CA")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
+10 -4
View File
@@ -13,6 +13,12 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)
// NewTestCaCert will create a new ca certificate // NewTestCaCert will create a new ca certificate
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
var err error var err error
@@ -34,10 +40,10 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
} }
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = testCertNow.Add(time.Second * -60)
} }
if after.IsZero() { if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second) after = testCertNow.Add(time.Second * 60)
} }
t := &TBSCertificate{ t := &TBSCertificate{
@@ -70,11 +76,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
// Expiry times are defaulted if you do not pass them in // Expiry times are defaulted if you do not pass them in
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = testCertNow.Add(time.Second * -60)
} }
if after.IsZero() { if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second) after = testCertNow.Add(time.Second * 60)
} }
if len(networks) == 0 { if len(networks) == 0 {
+10 -4
View File
@@ -14,6 +14,12 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)
// NewTestCaCert will create a new ca certificate // NewTestCaCert will create a new ca certificate
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
var err error var err error
@@ -35,10 +41,10 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
} }
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = testCertNow.Add(time.Second * -60)
} }
if after.IsZero() { if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second) after = testCertNow.Add(time.Second * 60)
} }
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
@@ -71,11 +77,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
// Expiry times are defaulted if you do not pass them in // Expiry times are defaulted if you do not pass them in
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = testCertNow.Add(time.Second * -60)
} }
if after.IsZero() { if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second) after = testCertNow.Add(time.Second * 60)
} }
var pub, priv []byte var pub, priv []byte
+30 -5
View File
@@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
if err = mustFlagString("out-key", cf.outKeyPath); err != nil { if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err return err
} }
} else {
// out-key is meaningless under PKCS#11 because the private key never
// leaves the HSM; reject it so we never silently accept or claim a
// stdout slot for it.
outKeySet := false
cf.set.Visit(func(f *flag.Flag) {
if f.Name == "out-key" {
outKeySet = true
}
})
if outKeySet {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
} }
if err := mustFlagString("out-crt", cf.outCertPath); err != nil { if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
return err return err
@@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
} }
} }
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-crt", *cf.outCertPath,
"out-qr", *cf.outQRPath,
); err != nil {
return err
}
var passphrase []byte var passphrase []byte
if !isP11 && *cf.encryption { if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE")) passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 { if len(passphrase) == 0 {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal { if err == ErrNoTerminal {
@@ -261,15 +283,17 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
Curve: curve, Curve: curve,
} }
if !isP11 { if !isP11 && !isStdio(*cf.outKeyPath) {
if _, err := os.Stat(*cf.outKeyPath); err == nil { if _, err := os.Stat(*cf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
} }
} }
if !isStdio(*cf.outCertPath) {
if _, err := os.Stat(*cf.outCertPath); err == nil { if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
} }
}
var c cert.Certificate var c cert.Certificate
var b []byte var b []byte
@@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv) b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
} }
err = os.WriteFile(*cf.outKeyPath, b, 0600) err = writeOutput(*cf.outKeyPath, b, 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-key: %s", err) return fmt.Errorf("error while writing out-key: %s", err)
} }
@@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while marshalling certificate: %s", err) return fmt.Errorf("error while marshalling certificate: %s", err)
} }
err = os.WriteFile(*cf.outCertPath, b, 0600) err = writeOutput(*cf.outCertPath, b, 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err) return fmt.Errorf("error while writing out-crt: %s", err)
} }
@@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while generating qr code: %s", err) return fmt.Errorf("error while generating qr code: %s", err)
} }
err = os.WriteFile(*cf.outQRPath, b, 0600) err = writeOutput(*cf.outQRPath, b, 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err) return fmt.Errorf("error while writing out-qr: %s", err)
} }
@@ -332,6 +356,7 @@ func caSummary() string {
func caHelp(out io.Writer) { func caHelp(out io.Writer) {
cf := newCaFlags() cf := newCaFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n")) out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n"))
out.Write([]byte(stdioHelpText))
cf.set.SetOutput(out) cf.set.SetOutput(out)
cf.set.PrintDefaults() cf.set.PrintDefaults()
} }
+72 -7
View File
@@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) {
assert.Equal( assert.Equal(
t, t,
"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+ "Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -argon-iterations uint\n"+ " -argon-iterations uint\n"+
" \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+ " \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+
" -argon-memory uint\n"+ " -argon-memory uint\n"+
@@ -84,7 +85,7 @@ func Test_ca(t *testing.T) {
err: nil, err: nil,
} }
pwPromptOb := "Enter passphrase: " pwPromptEB := "Enter passphrase: "
// required args // required args
assertHelpError(t, ca( assertHelpError(t, ca(
@@ -168,8 +169,8 @@ func Test_ca(t *testing.T) {
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, testpw)) require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Equal(t, pwPromptEB, eb.String())
// test encrypted key with passphrase environment variable // test encrypted key with passphrase environment variable
os.Remove(keyF.Name()) os.Remove(keyF.Name())
@@ -207,8 +208,8 @@ func Test_ca(t *testing.T) {
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.Error(t, ca(args, ob, eb, errpw)) require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Equal(t, pwPromptEB, eb.String())
// test when user fails to enter a password // test when user fails to enter a password
os.Remove(keyF.Name()) os.Remove(keyF.Name())
@@ -217,8 +218,8 @@ func Test_ca(t *testing.T) {
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up
// create valid cert/key for overwrite tests // create valid cert/key for overwrite tests
os.Remove(keyF.Name()) os.Remove(keyF.Name())
@@ -247,3 +248,67 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
} }
// Test_ca_stdio exercises "-" (stdout) handling for the ca subcommand's output
// flags: the certificate on stdout, the key on stdout, and the rejection of
// two outputs both claiming stdout, including the requirement that the
// rejection happens before any passphrase prompt.
func Test_ca_stdio(t *testing.T) {
	nopw := &StubPasswordReader{}

	// Only the generated names are needed. Close each handle immediately so the
	// descriptor is not leaked for the rest of the test and the file is not
	// removed while still open.
	keyF, err := os.CreateTemp("", "ca.key")
	require.NoError(t, err)
	require.NoError(t, keyF.Close())
	os.Remove(keyF.Name())
	defer os.Remove(keyF.Name())

	crtF, err := os.CreateTemp("", "ca.crt")
	require.NoError(t, err)
	require.NoError(t, crtF.Close())
	os.Remove(crtF.Name())
	defer os.Remove(crtF.Name())

	// out-crt on stdout, out-key on disk
	ob := &bytes.Buffer{}
	eb := &bytes.Buffer{}
	require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw))
	assert.Empty(t, eb.String())
	c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.True(t, c.IsCA())
	assert.Equal(t, "test-ca", c.Name())

	// out-key on stdout, out-crt on disk
	os.Remove(keyF.Name())
	ob.Reset()
	eb.Reset()
	require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw))
	assert.Empty(t, eb.String())
	_, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, cert.Curve_CURVE25519, curve)

	// dual stdout is rejected up front
	os.Remove(crtF.Name())
	ob.Reset()
	eb.Reset()
	require.EqualError(t,
		ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw),
		`-out-key and -out-crt both set to "-", only one output may write to stdout`)
	assert.Empty(t, ob.String())

	// an output conflict combined with -encrypt must error BEFORE prompting
	// for a passphrase; tracker records every read attempt
	tracker := &trackingPasswordReader{}
	ob.Reset()
	eb.Reset()
	require.EqualError(t,
		ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker),
		`-out-key and -out-crt both set to "-", only one output may write to stdout`)
	assert.Empty(t, ob.String())
	assert.Empty(t, eb.String())
	assert.Zero(t, tracker.calls, "passphrase prompt should not have been called")
}
// trackingPasswordReader is a password-reader test double that counts how many
// times a passphrase was requested.
type trackingPasswordReader struct {
	calls int // number of ReadPassword invocations so far
}

// ReadPassword records the call and returns an empty, non-nil passphrase with
// a nil error.
func (r *trackingPasswordReader) ReadPassword() ([]byte, error) {
	r.calls++
	return []byte{}, nil
}
+13 -2
View File
@@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
if err = mustFlagString("out-key", cf.outKeyPath); err != nil { if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err return err
} }
} else if *cf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
} }
if err = mustFlagString("out-pub", cf.outPubPath); err != nil { if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
return err return err
@@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
} }
} }
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-pub", *cf.outPubPath,
); err != nil {
return err
}
if isP11 { if isP11 {
p11Client, err := pkclient.FromUrl(*cf.p11url) p11Client, err := pkclient.FromUrl(*cf.p11url)
if err != nil { if err != nil {
@@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while getting public key: %w", err) return fmt.Errorf("error while getting public key: %w", err)
} }
} else { } else {
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-key: %s", err) return fmt.Errorf("error while writing out-key: %s", err)
} }
} }
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600) err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-pub: %s", err) return fmt.Errorf("error while writing out-pub: %s", err)
} }
@@ -102,6 +112,7 @@ func keygenSummary() string {
func keygenHelp(out io.Writer) { func keygenHelp(out io.Writer) {
cf := newKeygenFlags() cf := newKeygenFlags()
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
_, _ = out.Write([]byte(stdioHelpText))
cf.set.SetOutput(out) cf.set.SetOutput(out)
cf.set.PrintDefaults() cf.set.PrintDefaults()
} }
+41
View File
@@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) {
assert.Equal( assert.Equal(
t, t,
"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+ "Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -curve string\n"+ " -curve string\n"+
" \tECDH Curve (25519, P256) (default \"25519\")\n"+ " \tECDH Curve (25519, P256) (default \"25519\")\n"+
" -out-key string\n"+ " -out-key string\n"+
@@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, lPub, 32) assert.Len(t, lPub, 32)
} }
// Test_keygen_stdio exercises "-" (stdout) handling for the keygen
// subcommand's output flags: the public key on stdout, the private key on
// stdout, and the rejection of both outputs claiming stdout at once.
func Test_keygen_stdio(t *testing.T) {
	// Only the generated names are needed. Close each handle immediately so the
	// descriptor is not leaked for the rest of the test and the file is not
	// removed while still open.
	keyF, err := os.CreateTemp("", "test.key")
	require.NoError(t, err)
	require.NoError(t, keyF.Close())
	os.Remove(keyF.Name())
	defer os.Remove(keyF.Name())

	pubF, err := os.CreateTemp("", "test.pub")
	require.NoError(t, err)
	require.NoError(t, pubF.Close())
	os.Remove(pubF.Name())
	defer os.Remove(pubF.Name())

	// out-pub on stdout, out-key on disk
	ob := &bytes.Buffer{}
	eb := &bytes.Buffer{}
	require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb))
	assert.Empty(t, eb.String())
	lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, cert.Curve_CURVE25519, curve)
	assert.Len(t, lPub, 32)

	// out-key on stdout, out-pub on disk
	os.Remove(keyF.Name())
	ob.Reset()
	eb.Reset()
	require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb))
	assert.Empty(t, eb.String())
	lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, cert.Curve_CURVE25519, curve)
	assert.Len(t, lKey, 32)

	// both on stdout is a conflict caught up front
	ob.Reset()
	eb.Reset()
	require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb),
		`-out-key and -out-pub both set to "-", only one output may write to stdout`)
	assert.Empty(t, ob.String())
}
+3 -1
View File
@@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) {
} }
password, err := term.ReadPassword(int(os.Stdin.Fd())) password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println() // Terminal echo is off while reading, so the user's Enter key does not
// produce a visible newline. Emit one on stderr to match the prompt.
fmt.Fprintln(os.Stderr)
return password, err return password, err
} }
+18 -3
View File
@@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return err return err
} }
rawCert, err := os.ReadFile(*pf.path) var claims ioClaims
if err := reserveInputs(&claims, "path", *pf.path); err != nil {
return err
}
if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil {
return err
}
rawCert, err := readInput("path", *pf.path, &claims)
if err != nil { if err != nil {
return fmt.Errorf("unable to read cert; %s", err) return fmt.Errorf("unable to read cert; %s", err)
} }
// When the QR is going to stdout, suppress the human-readable text/json
// output so the binary stream is not contaminated.
qrToStdout := isStdio(*pf.outQRPath)
var c cert.Certificate var c cert.Certificate
var qrBytes []byte var qrBytes []byte
part := 0 part := 0
@@ -57,12 +69,14 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while unmarshaling cert: %s", err) return fmt.Errorf("error while unmarshaling cert: %s", err)
} }
if !qrToStdout {
if *pf.json { if *pf.json {
jsonCerts = append(jsonCerts, c) jsonCerts = append(jsonCerts, c)
} else { } else {
_, _ = out.Write([]byte(c.String())) _, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n")) _, _ = out.Write([]byte("\n"))
} }
}
if *pf.outQRPath != "" { if *pf.outQRPath != "" {
b, err := c.MarshalPEM() b, err := c.MarshalPEM()
@@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++ part++
} }
if *pf.json { if *pf.json && !qrToStdout {
b, _ := json.Marshal(jsonCerts) b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b) _, _ = out.Write(b)
_, _ = out.Write([]byte("\n")) _, _ = out.Write([]byte("\n"))
@@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while generating qr code: %s", err) return fmt.Errorf("error while generating qr code: %s", err)
} }
err = os.WriteFile(*pf.outQRPath, b, 0600) err = writeOutput(*pf.outQRPath, b, 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err) return fmt.Errorf("error while writing out-qr: %s", err)
} }
@@ -107,6 +121,7 @@ func printSummary() string {
func printHelp(out io.Writer) { func printHelp(out io.Writer) {
pf := newPrintFlags() pf := newPrintFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n")) out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n"))
out.Write([]byte(stdioHelpText))
pf.set.SetOutput(out) pf.set.SetOutput(out)
pf.set.PrintDefaults() pf.set.PrintDefaults()
} }
+39
View File
@@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) {
assert.Equal( assert.Equal(
t, t,
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+ "Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -json\n"+ " -json\n"+
" \tOptional: outputs certificates in json format\n"+ " \tOptional: outputs certificates in json format\n"+
" -out-qr string\n"+ " -out-qr string\n"+
@@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) {
ob.String(), ob.String(),
) )
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// read cert from stdin
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-"}, ob, eb)
require.NoError(t, err)
assert.Equal(
t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
assert.Empty(t, eb.String())
// -out-qr - sends only the PNG to stdout, suppressing the cert dump
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
stdout := ob.Bytes()
require.NotEmpty(t, stdout)
// PNG magic, no PEM/JSON noise prepended
assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8])
assert.NotContains(t, string(stdout), "NebulaCertificate")
assert.NotContains(t, string(stdout), `"details"`)
// json + out-qr - still suppresses json
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4])
assert.NotContains(t, ob.String(), `"details"`)
} }
// NewTestCaCert will generate a CA cert // NewTestCaCert will generate a CA cert
+38 -16
View File
@@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key") return newHelpErrorf("cannot set both -in-pub and -out-key")
} }
if isP11 && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
var v4Networks []netip.Prefix var v4Networks []netip.Prefix
var v6Networks []netip.Prefix var v6Networks []netip.Prefix
@@ -102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
} }
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
var claims ioClaims
if err := reserveInputs(&claims,
"ca-key", *sf.caKeyPath,
"ca-crt", *sf.caCertPath,
"in-pub", *sf.inPubPath,
); err != nil {
return err
}
if err := reserveOutputs(&claims,
"out-key", *sf.outKeyPath,
"out-crt", *sf.outCertPath,
"out-qr", *sf.outQRPath,
); err != nil {
return err
}
var curve cert.Curve var curve cert.Curve
var caKey []byte var caKey []byte
if !isP11 { if !isP11 {
var rawCAKey []byte var rawCAKey []byte
rawCAKey, err := os.ReadFile(*sf.caKeyPath) rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims)
if err != nil { if err != nil {
return fmt.Errorf("error while reading ca-key: %s", err) return fmt.Errorf("error while reading ca-key: %s", err)
} }
@@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if len(passphrase) == 0 { if len(passphrase) == 0 {
// ask for a passphrase until we get one // ask for a passphrase until we get one
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
if errors.Is(err, ErrNoTerminal) { if errors.Is(err, ErrNoTerminal) {
@@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
} }
rawCACert, err := os.ReadFile(*sf.caCertPath) rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims)
if err != nil { if err != nil {
return fmt.Errorf("error while reading ca-crt: %s", err) return fmt.Errorf("error while reading ca-crt: %s", err)
} }
@@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if *sf.inPubPath != "" { if *sf.inPubPath != "" {
var pubCurve cert.Curve var pubCurve cert.Curve
rawPub, err := os.ReadFile(*sf.inPubPath) rawPub, err := readInput("in-pub", *sf.inPubPath, &claims)
if err != nil { if err != nil {
return fmt.Errorf("error while reading in-pub: %s", err) return fmt.Errorf("error while reading in-pub: %s", err)
} }
@@ -266,17 +291,11 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve) pub, rawPriv = newKeypair(curve)
} }
if *sf.outKeyPath == "" { if !isStdio(*sf.outCertPath) {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
if _, err := os.Stat(*sf.outCertPath); err == nil { if _, err := os.Stat(*sf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
} }
}
var crts []cert.Certificate var crts []cert.Certificate
@@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
if !isP11 && *sf.inPubPath == "" { if !isP11 && *sf.inPubPath == "" {
if !isStdio(*sf.outKeyPath) {
if _, err := os.Stat(*sf.outKeyPath); err == nil { if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
} }
}
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-key: %s", err) return fmt.Errorf("error while writing out-key: %s", err)
} }
@@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
b = append(b, sb...) b = append(b, sb...)
} }
err = os.WriteFile(*sf.outCertPath, b, 0600) err = writeOutput(*sf.outCertPath, b, 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err) return fmt.Errorf("error while writing out-crt: %s", err)
} }
@@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("error while generating qr code: %s", err) return fmt.Errorf("error while generating qr code: %s", err)
} }
err = os.WriteFile(*sf.outQRPath, b, 0600) err = writeOutput(*sf.outQRPath, b, 0600, out)
if err != nil { if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err) return fmt.Errorf("error while writing out-qr: %s", err)
} }
@@ -440,6 +461,7 @@ func signSummary() string {
func signHelp(out io.Writer) { func signHelp(out io.Writer) {
sf := newSignFlags() sf := newSignFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n")) out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n"))
out.Write([]byte(stdioHelpText))
sf.set.SetOutput(out) sf.set.SetOutput(out)
sf.set.PrintDefaults() sf.set.PrintDefaults()
} }
+112 -8
View File
@@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) {
assert.Equal( assert.Equal(
t, t,
"Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+ "Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca-crt string\n"+ " -ca-crt string\n"+
" \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+ " \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+
" -ca-key string\n"+ " -ca-key string\n"+
@@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) {
// test with the proper password // test with the proper password
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, testpw)) require.NoError(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the proper password in the environment // test with the proper password in the environment
os.Remove(crtF.Name()) os.Remove(crtF.Name())
os.Remove(keyF.Name()) os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase)) os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
ob.Reset()
eb.Reset()
require.NoError(t, signCert(args, ob, eb, testpw)) require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "") os.Setenv("NEBULA_CA_PASSPHRASE", "")
@@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) {
testpw.password = []byte("invalid password") testpw.password = []byte("invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, testpw)) require.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the wrong password in environment // test with the wrong password in environment
ob.Reset() ob.Reset()
@@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, nopw)) require.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these // normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String())
// test an error condition // test an error condition
ob.Reset() ob.Reset()
@@ -425,6 +429,106 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, errpw)) require.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Equal(t, "Enter passphrase: ", eb.String())
}
func Test_signCert_stdio(t *testing.T) {
nopw := &StubPasswordReader{
password: []byte(""),
err: nil,
}
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)
ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
rawCACrt, _ := ca.MarshalPEM()
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
require.NoError(t, err)
defer os.Remove(caCrtF.Name())
caCrtF.Write(rawCACrt)
caKeyF, err := os.CreateTemp("", "sign-cert.key")
require.NoError(t, err)
defer os.Remove(caKeyF.Name())
caKeyF.Write(rawCAKey)
keyF, err := os.CreateTemp("", "sign.key")
require.NoError(t, err)
os.Remove(keyF.Name())
defer os.Remove(keyF.Name())
// ca-key on stdin, cert to stdout
withStdin(t, bytes.NewReader(rawCAKey))
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, "stdin-test", lCrt.Name())
assert.True(t, lCrt.CheckSignature(caPub))
// two flags reading from stdin should error before any read attempt;
// otherwise an interactive shell would hang on io.ReadAll
stdinIn := bytes.NewReader(rawCAKey)
withStdin(t, stdinIn)
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw),
`-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front")
// two flags writing to stdout should error before any output is written
// AND before stdin is consumed
stdinR := bytes.NewReader(rawCAKey)
withStdin(t, stdinR)
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw),
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
// stdin should be untouched because the conflict was caught up front
assert.Equal(t, len(rawCAKey), stdinR.Len())
// out-key on stdout, cert on disk
keyF2, err := os.CreateTemp("", "sign.key")
require.NoError(t, err)
os.Remove(keyF2.Name())
defer os.Remove(keyF2.Name())
crtF, err := os.CreateTemp("", "sign.crt")
require.NoError(t, err)
os.Remove(crtF.Name())
defer os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
_, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
// in-pub on stdin (caller already has a keypair, only the cert is generated)
inPub, _ := x25519Keypair()
rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)
withStdin(t, bytes.NewReader(rawInPub))
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, "in-pub-test", stdinCrt.Name())
assert.Equal(t, inPub, stdinCrt.PublicKey())
} }
+117
View File
@@ -0,0 +1,117 @@
package main
import (
"fmt"
"io"
"os"
)
const (
	// stdioPath is the sentinel path value: on an input flag it selects
	// stdin, on an output flag it selects stdout, in place of a file on disk.
	stdioPath = "-"

	// stdioHelpText is printed directly beneath the Usage line of each
	// subcommand's help so the "-" convention is described in one place
	// rather than repeated on every flag.
	stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"
)

// stdinReader supplies the bytes for any input flag set to "-". It is a
// package-level variable so tests can substitute a deterministic reader;
// tests that replace it must not use t.Parallel().
var stdinReader io.Reader = os.Stdin
// ioClaims remembers, for one command invocation, the name of the flag that
// has claimed stdin (in) and the flag that has claimed stdout (out). An empty
// string means the stream is unclaimed; the record lets a second flag asking
// for the same stream be refused.
type ioClaims struct {
	in, out string
}
// claimIn records flagName as the flag reading from stdin. Claiming again
// with the same flag name succeeds; if a different flag already holds stdin
// an error naming both flags is returned and the existing claim is kept.
func (c *ioClaims) claimIn(flagName string) error {
	switch holder := c.in; holder {
	case "", flagName:
		// Unclaimed, or a repeat claim by the current holder.
		c.in = flagName
		return nil
	default:
		return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", holder, flagName, stdioPath)
	}
}
// claimOut records flagName as the flag writing to stdout. Claiming again
// with the same flag name succeeds; if a different flag already holds stdout
// an error naming both flags is returned and the existing claim is kept.
func (c *ioClaims) claimOut(flagName string) error {
	switch holder := c.out; holder {
	case "", flagName:
		// Unclaimed, or a repeat claim by the current holder.
		c.out = flagName
		return nil
	default:
		return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", holder, flagName, stdioPath)
	}
}
// reserveInputs takes alternating (flagName, path) arguments and claims stdin
// on behalf of every flag whose path is stdioPath. Call it before reading any
// input: a conflict is then reported straight away rather than after the
// process has blocked in io.ReadAll waiting for input that never arrives.
func reserveInputs(claims *ioClaims, pairs ...string) error {
	claimStdin := (*ioClaims).claimIn
	return reserveStdio(claims, "reserveInputs", claimStdin, pairs)
}
// reserveOutputs takes alternating (flagName, path) arguments and claims
// stdout on behalf of every flag whose path is stdioPath. Call it before
// writing any output so that a conflict cannot leave one stream partially
// written when the second flag is rejected.
func reserveOutputs(claims *ioClaims, pairs ...string) error {
	claimStdout := (*ioClaims).claimOut
	return reserveStdio(claims, "reserveOutputs", claimStdout, pairs)
}
// reserveStdio is the shared implementation behind reserveInputs and
// reserveOutputs. pairs holds alternating flag names and paths; for each path
// equal to stdioPath, claim is invoked with the matching flag name and the
// first error it returns is passed back. who names the calling helper and is
// only used in the panic raised when pairs has an odd length, which is a
// programming error at the call site.
func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error {
	if len(pairs)%2 == 1 {
		panic(who + " requires alternating name, path pairs")
	}

	// Consume the slice two elements at a time: rest[0] is the flag name,
	// rest[1] is the path given to that flag.
	for rest := pairs; len(rest) > 0; rest = rest[2:] {
		if rest[1] != stdioPath {
			continue
		}
		if err := claim(claims, rest[0]); err != nil {
			return err
		}
	}
	return nil
}
// readInput returns the full contents named by path. Ordinary paths are read
// from disk; when path is stdioPath, stdin is claimed for flagName and
// stdinReader is drained instead. A claim conflict is returned as the error
// without reading anything.
func readInput(flagName, path string, claims *ioClaims) ([]byte, error) {
	if path != stdioPath {
		return os.ReadFile(path)
	}

	if err := claims.claimIn(flagName); err != nil {
		return nil, err
	}
	return io.ReadAll(stdinReader)
}
// openInput returns a reader for path. When path is stdioPath, stdin is
// claimed for flagName and the returned reader wraps stdinReader with a no-op
// Close; otherwise the file at path is opened and the caller must Close it.
//
// On any failure the returned reader is a true nil interface value, so
// callers may safely compare it against nil.
func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) {
	if path == stdioPath {
		if err := claims.claimIn(flagName); err != nil {
			return nil, err
		}
		return io.NopCloser(stdinReader), nil
	}

	f, err := os.Open(path)
	if err != nil {
		// Returning os.Open's results directly would wrap a nil *os.File in
		// the io.ReadCloser interface, producing a non-nil reader on error.
		return nil, err
	}
	return f, nil
}
// writeOutput sends data to the file at path, created with mode perm, or to
// the supplied stdout writer when path is stdioPath (perm is then unused).
// This function does not claim stdout itself: the caller is expected to have
// reserved it through reserveOutputs before passing stdioPath.
func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error {
	if path != stdioPath {
		return os.WriteFile(path, data, perm)
	}

	_, err := stdout.Write(data)
	return err
}
// isStdio reports whether path is exactly the stdio sentinel. Callers use it
// to bypass file existence checks such as "refuse to overwrite".
func isStdio(path string) bool {
	switch path {
	case stdioPath:
		return true
	default:
		return false
	}
}
+167
View File
@@ -0,0 +1,167 @@
package main
import (
"bytes"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// withStdin points stdinReader at r and registers a cleanup on t that puts
// the previous reader back when the test finishes.
func withStdin(t *testing.T, r io.Reader) {
	t.Helper()
	original := stdinReader
	t.Cleanup(func() { stdinReader = original })
	stdinReader = r
}
// Test_readInput_stdin checks that a "-" path is served from stdinReader and
// that the flag is recorded as the holder of stdin.
func Test_readInput_stdin(t *testing.T) {
	withStdin(t, bytes.NewBufferString("hello"))

	claims := &ioClaims{}
	data, err := readInput("path", "-", claims)
	require.NoError(t, err)
	assert.Equal(t, []byte("hello"), data)
	assert.Equal(t, "path", claims.in)
}
// Test_readInput_file checks that an ordinary path is read from disk and
// leaves stdin unclaimed.
func Test_readInput_file(t *testing.T) {
	target := filepath.Join(t.TempDir(), "f")
	require.NoError(t, os.WriteFile(target, []byte("file"), 0600))

	claims := &ioClaims{}
	data, err := readInput("path", target, claims)
	require.NoError(t, err)
	assert.Equal(t, []byte("file"), data)
	assert.Empty(t, claims.in)
}
// Test_readInput_doubleStdinErrors checks that a second flag reading from
// stdin is rejected once a different flag has claimed it.
func Test_readInput_doubleStdinErrors(t *testing.T) {
	withStdin(t, bytes.NewBufferString("hello"))
	claims := &ioClaims{}

	_, firstErr := readInput("ca-key", "-", claims)
	require.NoError(t, firstErr)

	_, secondErr := readInput("ca-crt", "-", claims)
	require.EqualError(t, secondErr, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
}
// Test_openInput_stdin checks that the reader returned for "-" yields the
// bytes of stdinReader.
func Test_openInput_stdin(t *testing.T) {
	withStdin(t, bytes.NewBufferString("hi"))

	claims := &ioClaims{}
	reader, err := openInput("ca", "-", claims)
	require.NoError(t, err)
	defer reader.Close()

	contents, err := io.ReadAll(reader)
	require.NoError(t, err)
	assert.Equal(t, []byte("hi"), contents)
}
// Test_openInput_doubleStdinErrors checks that opening stdin for a second
// flag fails even after the first reader has been closed.
func Test_openInput_doubleStdinErrors(t *testing.T) {
	withStdin(t, bytes.NewBufferString("hi"))
	claims := &ioClaims{}

	first, firstErr := openInput("ca", "-", claims)
	require.NoError(t, firstErr)
	first.Close()

	_, secondErr := openInput("crt", "-", claims)
	require.EqualError(t, secondErr, `-ca and -crt both set to "-", only one input may read from stdin`)
}
func Test_writeOutput_stdout(t *testing.T) {
out := &bytes.Buffer{}
err := writeOutput("-", []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Equal(t, "payload", out.String())
}
func Test_writeOutput_file(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
out := &bytes.Buffer{}
err := writeOutput(p, []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Empty(t, out.String())
got, err := os.ReadFile(p)
require.NoError(t, err)
assert.Equal(t, []byte("payload"), got)
}
// Test_reserveOutputs_noConflict checks that only the flag whose path is "-"
// claims stdout; file paths and empty paths are ignored.
func Test_reserveOutputs_noConflict(t *testing.T) {
	pairs := []string{
		"out-key", "/tmp/key",
		"out-crt", "-",
		"out-qr", "",
	}

	claims := &ioClaims{}
	require.NoError(t, reserveOutputs(claims, pairs...))
	assert.Equal(t, "out-crt", claims.out)
}
// Test_reserveOutputs_conflict checks that two flags both set to "-" are
// rejected with an error naming both flags.
func Test_reserveOutputs_conflict(t *testing.T) {
	pairs := []string{
		"out-key", "-",
		"out-crt", "-",
	}

	claims := &ioClaims{}
	err := reserveOutputs(claims, pairs...)
	require.EqualError(t, err, `-out-key and -out-crt both set to "-", only one output may write to stdout`)
}
func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
}()
var claims ioClaims
_ = reserveOutputs(&claims, "out-key")
}
// Test_reserveInputs_noConflict checks that only the flag whose path is "-"
// claims stdin; file paths and empty paths are ignored.
func Test_reserveInputs_noConflict(t *testing.T) {
	pairs := []string{
		"ca-key", "/tmp/ca.key",
		"ca-crt", "-",
		"in-pub", "",
	}

	claims := &ioClaims{}
	require.NoError(t, reserveInputs(claims, pairs...))
	assert.Equal(t, "ca-crt", claims.in)
}
// Test_reserveInputs_conflict checks that two flags both set to "-" are
// rejected with an error naming both flags.
func Test_reserveInputs_conflict(t *testing.T) {
	pairs := []string{
		"ca-key", "-",
		"ca-crt", "-",
	}

	claims := &ioClaims{}
	err := reserveInputs(claims, pairs...)
	require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
}
// Test_claimIn_idempotent checks that an up-front claim followed by a lazy
// re-claim from the same flag succeeds both times and keeps the holder.
func Test_claimIn_idempotent(t *testing.T) {
	claims := &ioClaims{}
	for attempt := 0; attempt < 2; attempt++ {
		require.NoError(t, claims.claimIn("ca-key"))
	}
	assert.Equal(t, "ca-key", claims.in)
}
// Test_claimOut_idempotent checks that claiming stdout twice from the same
// flag succeeds both times and keeps the holder.
func Test_claimOut_idempotent(t *testing.T) {
	claims := &ioClaims{}
	for attempt := 0; attempt < 2; attempt++ {
		require.NoError(t, claims.claimOut("out-crt"))
	}
	assert.Equal(t, "out-crt", claims.out)
}
// Test_isStdio checks that only the exact string "-" is treated as the stdio
// sentinel.
func Test_isStdio(t *testing.T) {
	cases := []struct {
		path string
		want bool
	}{
		{"-", true},
		{"", false},
		{"./-", false},
		{"foo", false},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, isStdio(tc.path))
	}
}
+13 -4
View File
@@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return err return err
} }
caFile, err := os.Open(*vf.caPath) var claims ioClaims
if err := reserveInputs(&claims,
"ca", *vf.caPath,
"crt", *vf.certPath,
); err != nil {
return err
}
caReader, err := openInput("ca", *vf.caPath, &claims)
if err != nil { if err != nil {
return fmt.Errorf("error while reading ca: %w", err) return fmt.Errorf("error while reading ca: %w", err)
} }
defer caFile.Close() defer caReader.Close()
caPool, err := cert.NewCAPoolFromPEMReader(caFile) caPool, err := cert.NewCAPoolFromPEMReader(caReader)
if err != nil && !errors.Is(err, cert.ErrExpired) { if err != nil && !errors.Is(err, cert.ErrExpired) {
return fmt.Errorf("error while adding ca cert to pool: %w", err) return fmt.Errorf("error while adding ca cert to pool: %w", err)
} }
rawCert, err := os.ReadFile(*vf.certPath) rawCert, err := readInput("crt", *vf.certPath, &claims)
if err != nil { if err != nil {
return fmt.Errorf("unable to read crt: %w", err) return fmt.Errorf("unable to read crt: %w", err)
} }
@@ -85,6 +93,7 @@ func verifySummary() string {
func verifyHelp(out io.Writer) { func verifyHelp(out io.Writer) {
vf := newVerifyFlags() vf := newVerifyFlags()
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
_, _ = out.Write([]byte(stdioHelpText))
vf.set.SetOutput(out) vf.set.SetOutput(out)
vf.set.PrintDefaults() vf.set.PrintDefaults()
} }
+44
View File
@@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) {
assert.Equal( assert.Equal(
t, t,
"Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+ "Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca string\n"+ " -ca string\n"+
" \tRequired: path to a file containing one or more ca certificates\n"+ " \tRequired: path to a file containing one or more ca certificates\n"+
" -crt string\n"+ " -crt string\n"+
@@ -122,3 +123,46 @@ func Test_verify(t *testing.T) {
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
require.NoError(t, err) require.NoError(t, err)
} }
// Test_verify_stdio runs verify with "-" given for -crt, then for -ca, and
// finally for both flags, which must be rejected.
func Test_verify_stdio(t *testing.T) {
	stdout := &bytes.Buffer{}
	stderr := &bytes.Buffer{}

	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
	ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
	caPEM, _ := ca.MarshalPEM()

	crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
	crtPEM, _ := crt.MarshalPEM()

	caOnDisk, err := os.CreateTemp("", "verify-ca")
	require.NoError(t, err)
	defer os.Remove(caOnDisk.Name())
	caOnDisk.Write(caPEM)

	// crt on stdin, ca on disk
	withStdin(t, bytes.NewReader(crtPEM))
	args := []string{"-ca", caOnDisk.Name(), "-crt", "-"}
	require.NoError(t, verify(args, stdout, stderr))
	assert.Empty(t, stdout.String())
	assert.Empty(t, stderr.String())

	// ca on stdin, crt on disk
	crtOnDisk, err := os.CreateTemp("", "verify-cert")
	require.NoError(t, err)
	defer os.Remove(crtOnDisk.Name())
	crtOnDisk.Write(crtPEM)

	withStdin(t, bytes.NewReader(caPEM))
	stdout.Reset()
	stderr.Reset()
	args = []string{"-ca", "-", "-crt", crtOnDisk.Name()}
	require.NoError(t, verify(args, stdout, stderr))
	assert.Empty(t, stdout.String())
	assert.Empty(t, stderr.String())

	// both flags on stdin should error
	withStdin(t, bytes.NewReader(caPEM))
	stdout.Reset()
	stderr.Reset()
	args = []string{"-ca", "-", "-crt", "-"}
	require.EqualError(t, verify(args, stdout, stderr),
		`-ca and -crt both set to "-", only one input may read from stdin`)
}
+5 -2
View File
@@ -61,10 +61,13 @@ func main() {
} }
if *configPath == "" { if *configPath == "" {
fmt.Println("-config flag must be set") p, err := config.DefaultPath()
flag.Usage() if err != nil {
fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
*configPath = p
}
c := config.NewC(l) c := config.NewC(l)
err := c.Load(*configPath) err := c.Load(*configPath)
+2 -15
View File
@@ -3,8 +3,6 @@ package main
import ( import (
"fmt" "fmt"
"log" "log"
"os"
"path/filepath"
"github.com/kardianos/service" "github.com/kardianos/service"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
@@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error {
return nil return nil
} }
func fileExists(filename string) bool {
_, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return true
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error { func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
if *configPath == "" { if *configPath == "" {
ex, err := os.Executable() p, err := config.DefaultPath()
if err != nil { if err != nil {
return err return err
} }
*configPath = filepath.Dir(ex) + "/config.yaml" *configPath = p
if !fileExists(*configPath) {
*configPath = filepath.Dir(ex) + "/config.yml"
}
} }
svcConfig := &service.Config{ svcConfig := &service.Config{
+5 -2
View File
@@ -50,10 +50,13 @@ func main() {
} }
if *configPath == "" { if *configPath == "" {
fmt.Println("-config flag must be set") p, err := config.DefaultPath()
flag.Usage() if err != nil {
fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
*configPath = p
}
l := logging.NewLogger(os.Stdout) l := logging.NewLogger(os.Stdout)
+29
View File
@@ -0,0 +1,29 @@
package config
import (
"fmt"
"os"
"path/filepath"
)
// DefaultPath looks in the directory that holds the running executable and
// returns the path of the config file found there, choosing config.yaml ahead
// of config.yml. An error is returned when the executable's location cannot
// be determined, or when neither file exists; in the latter case the error
// names both paths that were checked.
func DefaultPath() (string, error) {
	exePath, err := os.Executable()
	if err != nil {
		return "", err
	}

	exeDir := filepath.Dir(exePath)
	return defaultPathInDir(exeDir)
}
// defaultPathInDir returns the first of dir/config.yaml and dir/config.yml
// for which os.Stat succeeds, in that order. If neither can be stat'ed the
// returned path is empty and the error lists both candidates.
func defaultPathInDir(dir string) (string, error) {
	// Ordered by preference: the .yaml spelling wins when both are present.
	candidates := [...]string{
		filepath.Join(dir, "config.yaml"),
		filepath.Join(dir, "config.yml"),
	}

	for _, candidate := range candidates {
		if _, err := os.Stat(candidate); err == nil {
			return candidate, nil
		}
	}
	return "", fmt.Errorf("no default config found at %s or %s", candidates[0], candidates[1])
}
+67
View File
@@ -0,0 +1,67 @@
package config
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDefaultPathInDir covers the lookup order of defaultPathInDir: both
// files present, each file alone, and neither file present.
func TestDefaultPathInDir(t *testing.T) {
	// writeConfig creates name inside dir with the given body and returns
	// the full path of the new file.
	writeConfig := func(t *testing.T, dir, name, body string) string {
		t.Helper()
		full := filepath.Join(dir, name)
		require.NoError(t, os.WriteFile(full, []byte(body), 0644))
		return full
	}

	t.Run("prefers config.yaml when both exist", func(t *testing.T) {
		dir := t.TempDir()
		expected := writeConfig(t, dir, "config.yaml", "a: 1")
		writeConfig(t, dir, "config.yml", "a: 2")

		actual, err := defaultPathInDir(dir)
		require.NoError(t, err)
		assert.Equal(t, expected, actual)
	})

	t.Run("returns config.yaml when only it exists", func(t *testing.T) {
		dir := t.TempDir()
		expected := writeConfig(t, dir, "config.yaml", "a: 1")

		actual, err := defaultPathInDir(dir)
		require.NoError(t, err)
		assert.Equal(t, expected, actual)
	})

	t.Run("falls back to config.yml when only it exists", func(t *testing.T) {
		dir := t.TempDir()
		expected := writeConfig(t, dir, "config.yml", "a: 1")

		actual, err := defaultPathInDir(dir)
		require.NoError(t, err)
		assert.Equal(t, expected, actual)
	})

	t.Run("errors when neither exists and names both paths", func(t *testing.T) {
		dir := t.TempDir()

		actual, err := defaultPathInDir(dir)
		assert.Empty(t, actual)
		require.Error(t, err)
		for _, name := range []string{"config.yaml", "config.yml"} {
			assert.Contains(t, err.Error(), filepath.Join(dir, name))
		}
	})
}
// TestDefaultPath checks DefaultPath against the directory of the test
// binary: on failure the error must mention that directory, on success the
// result must be config.yaml or config.yml inside it.
func TestDefaultPath(t *testing.T) {
	found, err := DefaultPath()

	exe, exeErr := os.Executable()
	require.NoError(t, exeErr)
	exeDir := filepath.Dir(exe)

	if err != nil {
		assert.Contains(t, err.Error(), exeDir)
		return
	}

	assert.Equal(t, exeDir, filepath.Dir(found))
	assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(found))
}
+7 -37
View File
@@ -11,7 +11,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@@ -45,8 +44,6 @@ type connectionManager struct {
inactivityTimeout atomic.Int64 inactivityTimeout atomic.Int64
dropInactive atomic.Bool dropInactive atomic.Bool
metricsTxPunchy metrics.Counter
l *slog.Logger l *slog.Logger
} }
@@ -57,7 +54,6 @@ func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p
punchy: p, punchy: p,
relayUsed: make(map[uint32]struct{}), relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{}, relayUsedLock: &sync.RWMutex{},
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
} }
cm.reload(c, true) cm.reload(c, true)
@@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if !outTraffic { if !outTraffic {
// Send a punch packet to keep the NAT state alive // Send a punch packet to keep the NAT state alive
cm.sendPunch(hostinfo) cm.punchy.SendPunch(hostinfo)
} }
return decision, hostinfo, primary return decision, hostinfo, primary
@@ -400,17 +396,16 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so. // Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo) cm.punchy.SendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
return doNothing, nil, nil return doNothing, nil, nil
} }
if cm.punchy.GetTargetEverything() { // We aren't receiving traffic but we are sending it. The outbound
// This is similar to the old punchy behavior with a slight optimization. // traffic itself refreshes the primary remote's NAT state; this
// We aren't receiving traffic but we are sending it, punch on all known // fans out to non-primary remotes, but only if target_all_remotes
// ips in case we need to re-prime NAT state // is configured.
cm.sendPunch(hostinfo) cm.punchy.SendPunchToAll(hostinfo)
}
if cm.l.Enabled(context.Background(), slog.LevelDebug) { if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status", hostinfo.logger(cm.l).Debug("Tunnel status",
@@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
} }
} }
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !cm.punchy.GetPunch() {
// Punching is disabled
return
}
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
// would lose the ability to notify us and punchy.respond would become unreliable.
return
}
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState() cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert curCrt := hostinfo.ConnectionState.myCert
+4 -4
View File
@@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Create manager // Create manager
conf := config.NewC(test.NewLogger()) conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
p := []byte("") p := []byte("")
@@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Create manager // Create manager
conf := config.NewC(test.NewLogger()) conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
p := []byte("") p := []byte("")
@@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
conf.Settings["tunnels"] = map[string]any{ conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true, "drop_inactive": true,
} }
punchy := NewPunchyFromConfig(test.NewLogger(), conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load()) assert.True(t, nc.dropInactive.Load())
nc.intf = ifce nc.intf = ifce
@@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
// Create manager // Create manager
conf := config.NewC(test.NewLogger()) conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
ifce.connectionManager = nc ifce.connectionManager = nc
+5 -4
View File
@@ -7,13 +7,14 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/handshake" "github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/noiseutil"
) )
const ReplayWindow = 1024 const ReplayWindow = 1024
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey noiseutil.CipherState
dKey *NebulaCipherState dKey noiseutil.CipherState
myCert cert.Certificate myCert cert.Certificate
peerCert *cert.CachedCertificate peerCert *cert.CachedCertificate
initiator bool initiator bool
@@ -31,8 +32,8 @@ func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
myCert: r.MyCert, myCert: r.MyCert,
initiator: r.Initiator, initiator: r.Initiator,
peerCert: r.RemoteCert, peerCert: r.RemoteCert,
eKey: NewNebulaCipherState(r.EKey), eKey: noiseutil.NewCipherState(r.EKey, r.Cipher),
dKey: NewNebulaCipherState(r.DKey), dKey: noiseutil.NewCipherState(r.DKey, r.Cipher),
window: NewBits(ReplayWindow), window: NewBits(ReplayWindow),
} }
ci.messageCounter.Add(r.MessageIndex) ci.messageCounter.Add(r.MessageIndex)
+12 -60
View File
@@ -5,8 +5,6 @@ package nebula
import ( import (
"net/netip" "net/netip"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@@ -22,7 +20,9 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message
panic(err) panic(err)
} }
pipeTo.InjectUDPPacket(p) pipeTo.InjectUDPPacket(p)
if h.Type == msgType && h.Subtype == subType { match := h.Type == msgType && h.Subtype == subType
p.Release()
if match {
return return
} }
} }
@@ -38,7 +38,9 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
panic(err) panic(err)
} }
pipeTo.InjectUDPPacket(p) pipeTo.InjectUDPPacket(p)
if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType { match := h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType
p.Release()
if match {
return return
} }
} }
@@ -90,65 +92,15 @@ func (c *Control) GetTunTxChan() <-chan []byte {
return c.f.inside.(*overlay.TestTun).TxPackets return c.f.inside.(*overlay.TestTun).TxPackets
} }
// InjectUDPPacket will inject a packet into the udp side of nebula // InjectUDPPacket injects a packet into the udp side. We copy internally so the caller keeps ownership of p.
// The copy comes from the freelist so steady-state alloc is zero.
func (c *Control) InjectUDPPacket(p *udp.Packet) { func (c *Control) InjectUDPPacket(p *udp.Packet) {
c.f.outside.(*udp.TesterConn).Send(p) c.f.outside.(*udp.TesterConn).Send(p.Copy())
} }
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol // InjectTunPacket pushes an IP packet onto the tun interface.
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { func (c *Control) InjectTunPacket(packet []byte) {
serialize := make([]gopacket.SerializableLayer, 0) c.f.inside.(*overlay.TestTun).Send(packet)
var netLayer gopacket.NetworkLayer
if toAddr.Is6() {
if !fromAddr.Is6() {
panic("Cant send ipv6 to ipv4")
}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} else {
if !fromAddr.Is4() {
panic("Cant send ipv4 to ipv6")
}
ip := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
}
udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
err := udp.SetNetworkLayerForChecksum(netLayer)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
serialize = append(serialize, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, serialize...)
if err != nil {
panic(err)
}
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
} }
func (c *Control) GetVpnAddrs() []netip.Addr { func (c *Control) GetVpnAddrs() []netip.Addr {
+82 -16
View File
@@ -11,7 +11,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/gaissmai/bart"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
@@ -23,7 +22,10 @@ type dnsServer struct {
dnsMap4 map[string]netip.Addr dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr dnsMap6 map[string]netip.Addr
hostMap *HostMap hostMap *HostMap
myVpnAddrsTable *bart.Lite pki *PKI
// selfHost is the cached FQDN we last seeded for ourselves
selfHost string
mux *dns.ServeMux mux *dns.ServeMux
@@ -55,14 +57,14 @@ type dnsServer struct {
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel // they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
// watcher that tears the listener down on nebula shutdown. The returned // watcher that tears the listener down on nebula shutdown. The returned
// pointer is always non-nil, even on error. // pointer is always non-nil, even on error.
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) {
ds := &dnsServer{ ds := &dnsServer{
l: l, l: l,
ctx: ctx, ctx: ctx,
dnsMap4: make(map[string]netip.Addr), dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap, hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable, pki: pki,
} }
ds.mux = dns.NewServeMux() ds.mux = dns.NewServeMux()
ds.mux.HandleFunc(".", ds.handleDnsRequest) ds.mux.HandleFunc(".", ds.handleDnsRequest)
@@ -76,6 +78,7 @@ func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState,
if err := ds.reload(c, true); err != nil { if err := ds.reload(c, true); err != nil {
return ds, err return ds, err
} }
ds.seedSelf()
return ds, nil return ds, nil
} }
@@ -113,7 +116,7 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
d.Stop() d.Stop()
} }
// Drop any records that accumulated while enabled; a later re-enable // Drop any records that accumulated while enabled; a later re-enable
// will repopulate from fresh handshakes. // will repopulate from fresh handshakes and a fresh seedSelf.
d.clearRecords() d.clearRecords()
return nil return nil
} }
@@ -121,17 +124,14 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
if running == nil { if running == nil {
// Was disabled (or never started); bring it up now. // Was disabled (or never started); bring it up now.
go d.Start() go d.Start()
return nil } else if !sameAddr {
}
if sameAddr {
return nil
}
d.shutdownServer(running, runningStarted, "reload") d.shutdownServer(running, runningStarted, "reload")
// Old Start goroutine has now exited; bring up a fresh listener on the // Old Start goroutine has now exited; bring up a fresh listener on the new address.
// new address.
go d.Start() go d.Start()
}
// Refresh the self entry every enabled reload so cert renewals that change our name or VPN addresses are picked up.
d.seedSelf()
return nil return nil
} }
@@ -249,6 +249,20 @@ func (d *dnsServer) QueryCert(data string) string {
return "" return ""
} }
// The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves.
// Answer self lookups straight from the local cert state.
if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) {
c := cs.GetDefaultCertificate()
if c == nil {
return ""
}
b, err := c.MarshalJSON()
if err != nil {
return ""
}
return string(b)
}
hostinfo := d.hostMap.QueryVpnAddr(ip) hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil { if hostinfo == nil {
return "" return ""
@@ -266,12 +280,60 @@ func (d *dnsServer) QueryCert(data string) string {
return string(b) return string(b)
} }
// clearRecords drops all DNS records. // clearRecords drops all DNS records, including the self entry.
func (d *dnsServer) clearRecords() { func (d *dnsServer) clearRecords() {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
clear(d.dnsMap4) clear(d.dnsMap4)
clear(d.dnsMap6) clear(d.dnsMap6)
d.selfHost = ""
}
// seedSelf publishes (or refreshes) a DNS record that maps our own certificate
// name to our VPN addresses, so a single-lighthouse network can resolve the
// lighthouse's own hostname without the two-process workaround.
//
// It does nothing while DNS is disabled, or when no cert state or default
// certificate is available. At most one IPv4 and one IPv6 address are
// recorded: the first of each family found in the cert state's VPN addresses.
func (d *dnsServer) seedSelf() {
	if !d.enabled.Load() {
		return
	}

	state := d.certState()
	if state == nil {
		return
	}

	crt := state.GetDefaultCertificate()
	if crt == nil {
		return
	}

	// Records are keyed by the lower-cased cert name with a trailing dot.
	fqdn := strings.ToLower(crt.Name()) + "."

	d.Lock()
	defer d.Unlock()

	// The cert name differs from the one we seeded last time: retire the stale entry.
	if prev := d.selfHost; prev != "" && prev != fqdn {
		delete(d.dnsMap4, prev)
		delete(d.dnsMap6, prev)
	}
	d.selfHost = fqdn

	// Clear both families first so one we no longer hold does not linger.
	delete(d.dnsMap4, fqdn)
	delete(d.dnsMap6, fqdn)

	var seeded4, seeded6 bool
	for _, vpnAddr := range state.myVpnAddrs {
		switch {
		case vpnAddr.Is4() && !seeded4:
			d.dnsMap4[fqdn] = vpnAddr
			seeded4 = true
		case vpnAddr.Is6() && !seeded6:
			d.dnsMap6[fqdn] = vpnAddr
			seeded6 = true
		}
		// Both families are covered; later addresses cannot change the result.
		if seeded4 && seeded6 {
			break
		}
	}
}
func (d *dnsServer) certState() *CertState {
if d.pki == nil {
return nil
}
return d.pki.getCertState()
} }
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
@@ -309,8 +371,12 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
return true return true
} }
cs := d.certState()
if cs == nil || cs.myVpnAddrsTable == nil {
return false
}
//if we found it in this table, it's good //if we found it in this table, it's good
return d.myVpnAddrsTable.Contains(b) return cs.myVpnAddrsTable.Contains(b)
} }
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
+89
View File
@@ -9,7 +9,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -276,6 +279,92 @@ func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
} }
} }
// newTestPKI builds a minimal *PKI holding a single v2 cert (signed by a
// throwaway v2 CA) whose name and VPN addresses are caller-provided, suitable
// for exercising seedSelf and QueryCert self handling.
//
// Each address is turned into a host prefix (/32 for IPv4, /128 for IPv6) that
// is used both as a cert network and as an entry in myVpnAddrsTable, so the
// cert and the lookup table always agree.
func newTestPKI(t *testing.T, name string, addrs []netip.Addr) *PKI {
	t.Helper()

	networks := make([]netip.Prefix, 0, len(addrs))
	addrsTable := new(bart.Lite)
	for _, a := range addrs {
		// BitLen is 32 for IPv4 and 128 for IPv6, i.e. a single-host prefix.
		hostPrefix := netip.PrefixFrom(a, a.BitLen())
		networks = append(networks, hostPrefix)
		addrsTable.Insert(hostPrefix)
	}

	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
	c, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, nil, nil)

	cs := &CertState{
		v2Cert:            c,
		initiatingVersion: cert.Version2,
		myVpnAddrs:        addrs,
		myVpnAddrsTable:   addrsTable,
	}

	pki := &PKI{}
	pki.cs.Store(cs)
	return pki
}
// TestDnsServer_seedSelf_addsOwnRecord verifies that after seedSelf our own
// cert name ("lighthouse.") resolves to our IPv4 address for A queries and to
// our IPv6 address for AAAA queries.
func TestDnsServer_seedSelf_addsOwnRecord(t *testing.T) {
	ds, c := newTestDnsServer(t)

	wantA := netip.MustParseAddr("10.0.0.1")
	wantAAAA := netip.MustParseAddr("fd00::1")
	ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{wantA, wantAAAA})

	// NOTE(review): the final setDnsConfig argument is presumably the DNS
	// enable flag (it is false in the "disabled" test) — confirm against the helper.
	setDnsConfig(c, "127.0.0.1", "0", true, true)
	require.NoError(t, ds.reload(c, true))

	ds.seedSelf()

	gotA, foundA := ds.Query(dns.TypeA, "lighthouse.")
	assert.True(t, foundA)
	assert.Equal(t, wantA, gotA)

	gotAAAA, foundAAAA := ds.Query(dns.TypeAAAA, "lighthouse.")
	assert.True(t, foundAAAA)
	assert.Equal(t, wantAAAA, gotAAAA)
}
// TestDnsServer_seedSelf_disabled_noOp verifies that seedSelf leaves no record
// for our own cert name when the server was reloaded with the last
// setDnsConfig argument set to false.
func TestDnsServer_seedSelf_disabled_noOp(t *testing.T) {
	ds, c := newTestDnsServer(t)

	selfAddrs := []netip.Addr{netip.MustParseAddr("10.0.0.1")}
	ds.pki = newTestPKI(t, "lighthouse", selfAddrs)

	setDnsConfig(c, "127.0.0.1", "0", true, false)
	require.NoError(t, ds.reload(c, true))

	ds.seedSelf()

	_, found := ds.Query(dns.TypeA, "lighthouse.")
	assert.False(t, found)
}
// TestDnsServer_clearRecords_dropsSelfHost verifies that clearRecords resets
// the cached selfHost and removes the self record that seedSelf created.
func TestDnsServer_clearRecords_dropsSelfHost(t *testing.T) {
	ds, c := newTestDnsServer(t)

	selfAddrs := []netip.Addr{netip.MustParseAddr("10.0.0.1")}
	ds.pki = newTestPKI(t, "lighthouse", selfAddrs)

	setDnsConfig(c, "127.0.0.1", "0", true, true)
	require.NoError(t, ds.reload(c, true))

	// Precondition: seeding must have populated selfHost, otherwise the
	// assertions after clearRecords would pass vacuously.
	ds.seedSelf()
	require.NotEmpty(t, ds.selfHost)

	ds.clearRecords()
	assert.Empty(t, ds.selfHost)

	_, found := ds.Query(dns.TypeA, "lighthouse.")
	assert.False(t, found)
}
// TestDnsServer_QueryCert_returnsOwnCert verifies that QueryCert returns a
// non-empty result for one of our own VPN addresses and an empty result for an
// address that is neither ours nor a known peer.
func TestDnsServer_QueryCert_returnsOwnCert(t *testing.T) {
	ds, _ := newTestDnsServer(t)

	self := netip.MustParseAddr("10.0.0.1")
	ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{self})

	// Queries are written with a trailing dot, as in the other lookups in this file.
	ownCert := ds.QueryCert(self.String() + ".")
	assert.NotEmpty(t, ownCert, "TXT lookup of our own VPN address should return our cert")

	stranger := netip.MustParseAddr("10.0.0.99")
	assert.Empty(t, ds.QueryCert(stranger.String()+"."), "unknown peer IP should return nothing")
}
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
port := freeUDPPort(t) port := freeUDPPort(t)
ds, c := newTestDnsServer(t) ds, c := newTestDnsServer(t)
+12 -12
View File
@@ -47,7 +47,7 @@ func TestHandshakeRetransmitDuplicate(t *testing.T) {
defer r.RenderFlow() defer r.RenderFlow()
t.Log("Trigger handshake from me to them") t.Log("Trigger handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
t.Log("Grab my msg1") t.Log("Grab my msg1")
msg1 := myControl.GetFromUDP(true) msg1 := myControl.GetFromUDP(true)
@@ -97,7 +97,7 @@ func TestHandshakeTruncatedPacketRecovery(t *testing.T) {
defer r.RenderFlow() defer r.RenderFlow()
t.Log("Trigger handshake") t.Log("Trigger handshake")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
t.Log("Get msg1 and deliver to responder") t.Log("Get msg1 and deliver to responder")
msg1 := myControl.GetFromUDP(true) msg1 := myControl.GetFromUDP(true)
@@ -146,7 +146,7 @@ func TestHandshakeOrphanedMsg2Dropped(t *testing.T) {
defer r.RenderFlow() defer r.RenderFlow()
t.Log("Complete a normal handshake") t.Log("Complete a normal handshake")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
r.RouteForAllUntilTxTun(theirControl) r.RouteForAllUntilTxTun(theirControl)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
@@ -248,7 +248,7 @@ func TestHandshakeLateResponse(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger handshake from me") t.Log("Trigger handshake from me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
t.Log("Grab msg1 but don't deliver") t.Log("Grab msg1 but don't deliver")
msg1 := myControl.GetFromUDP(true) msg1 := myControl.GetFromUDP(true)
@@ -292,7 +292,7 @@ func TestHandshakeSelfConnectionRejected(t *testing.T) {
myControl.Start() myControl.Start()
t.Log("Trigger handshake from me") t.Log("Trigger handshake from me")
myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) myControl.InjectTunPacket(BuildTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
msg1 := myControl.GetFromUDP(true) msg1 := myControl.GetFromUDP(true)
t.Log("Drain any handshake retransmits before injecting") t.Log("Drain any handshake retransmits before injecting")
@@ -375,7 +375,7 @@ func TestHandshakeRemoteAllowList(t *testing.T) {
defer r.RenderFlow() defer r.RenderFlow()
t.Log("Trigger handshake from them") t.Log("Trigger handshake from them")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")))
msg1 := theirControl.GetFromUDP(true) msg1 := theirControl.GetFromUDP(true)
t.Log("Rewrite the source to a blocked IP and inject") t.Log("Rewrite the source to a blocked IP and inject")
@@ -426,7 +426,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
defer r.RenderFlow() defer r.RenderFlow()
t.Log("Complete a normal handshake via the router") t.Log("Complete a normal handshake via the router")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
r.RouteForAllUntilTxTun(theirControl) r.RouteForAllUntilTxTun(theirControl)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
@@ -437,7 +437,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
originalRemote := hi.CurrentRemote originalRemote := hi.CurrentRemote
t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)") t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")))
r.RouteForAllUntilTxTun(theirControl) r.RouteForAllUntilTxTun(theirControl)
t.Log("Verify tunnel still works") t.Log("Verify tunnel still works")
@@ -475,8 +475,8 @@ func TestHandshakeWrongResponderPacketStore(t *testing.T) {
evilControl.Start() evilControl.Start()
t.Log("Send multiple packets to them (cached during handshake)") t.Log("Send multiple packets to them (cached during handshake)")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")))
t.Log("Route until evil tunnel is closed") t.Log("Route until evil tunnel is closed")
h := &header.H{} h := &header.H{}
@@ -540,7 +540,7 @@ func TestHandshakeRelayComplete(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger handshake via relay") t.Log("Trigger handshake via relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -568,7 +568,7 @@ func TestHandshakeRelayComplete(t *testing.T) {
} }
// NOTE: Relay V1 cert + IPv6 rejection is not tested here because // NOTE: Relay V1 cert + IPv6 rejection is not tested here because
// InjectTunUDPPacket from a V4 node to a V6 address panics in the test // BuildTunUDPPacket from a V4 node to a V6 address panics in the test
// framework. The check is in handshake_manager.go handleOutbound relay // framework. The check is in handshake_manager.go handleOutbound relay
// logic (lines ~304-313): if the relay host has a V1 cert and either // logic (lines ~304-313): if the relay host has a V1 cert and either
// address is IPv6, the relay is skipped. // address is IPv6, the relay is skipped.
+46 -30
View File
@@ -16,6 +16,7 @@ import (
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -39,11 +40,22 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs() r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
// Pre-build the IP packet bytes once so the bench measures the data plane,
// not gopacket SerializeLayers overhead.
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
// EnableFanIn switches the router to a 0-alloc routing path. Required
// for hot-path benchmarks; would conflict with GetFromUDP-using tests.
r.EnableFanIn()
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(prebuilt)
_ = r.RouteForAllUntilTxTun(theirControl) // Release the TUN-side bytes back to the harness freelist; the bench
// just confirms a packet arrived, the contents aren't inspected.
overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
} }
myControl.Stop() myControl.Stop()
@@ -71,11 +83,15 @@ func BenchmarkHotPathRelay(b *testing.B) {
theirControl.Start() theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
r.EnableFanIn()
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(prebuilt)
_ = r.RouteForAllUntilTxTun(theirControl) overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
} }
myControl.Stop() myControl.Stop()
@@ -97,7 +113,7 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
t.Log("Have them consume my stage 0 packet. They have a tunnel now") t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -191,7 +207,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start() evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed") t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
h := &header.H{} h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -273,7 +289,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
evilControl.Start() evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed") t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
h := &header.H{} h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -352,8 +368,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake to start on both me and them") t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
t.Log("Get both stage 1 handshake packets") t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true) myHsForThem := myControl.GetFromUDP(true)
@@ -430,7 +446,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Trigger a handshake from me to them") r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -441,7 +457,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")))
p = r.RouteForAllUntilTxTun(theirControl) p = r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -480,7 +496,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Trigger a handshake from me to them") r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -492,7 +508,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")))
p = r.RouteForAllUntilTxTun(myControl) p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
@@ -535,7 +551,7 @@ func TestRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
@@ -565,7 +581,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
@@ -595,14 +611,14 @@ func TestReestablishRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Ensure packet traversal from them to me via the relay") t.Log("Ensure packet traversal from them to me via the relay")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl) p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
@@ -617,7 +633,7 @@ func TestReestablishRelays(t *testing.T) {
for curIndexes >= start { for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes) curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes) r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
return router.RouteAndExit return router.RouteAndExit
@@ -634,7 +650,7 @@ func TestReestablishRelays(t *testing.T) {
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p = r.RouteForAllUntilTxTun(theirControl) p = r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
@@ -669,7 +685,7 @@ func TestReestablishRelays(t *testing.T) {
t.Log("Assert the tunnel works the other way, too") t.Log("Assert the tunnel works the other way, too")
for { for {
t.Log("RouteForAllUntilTxTun") t.Log("RouteForAllUntilTxTun")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl) p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
@@ -739,8 +755,8 @@ func TestStage1RaceRelays(t *testing.T) {
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me") r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
r.Log("Wait for a packet from them to me") r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl) p := r.RouteForAllUntilTxTun(myControl)
@@ -787,8 +803,8 @@ func TestStage1RaceRelays2(t *testing.T) {
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me") r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -852,7 +868,7 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
@@ -957,7 +973,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
@@ -1259,8 +1275,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes") t.Log("Start both handshakes")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
t.Log("Get both stage 1") t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true) myStage1ForThem := myControl.GetFromUDP(true)
@@ -1476,7 +1492,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
t.Log("Have them consume my stage 0 packet. They have a tunnel now") t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -1504,7 +1520,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply //reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")))
//wait for reply //wait for reply
theirControl.WaitForType(1, 0, myControl) theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true) theirCachedPacket := myControl.GetFromTun(true)
+57 -2
View File
@@ -294,12 +294,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) controlB.InjectTunPacket(BuildTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them // And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) controlA.InjectTunPacket(BuildTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")))
aPacket := r.RouteForAllUntilTxTun(controlB) aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
} }
@@ -408,3 +408,58 @@ func testLogLevelName() string {
} }
return "info" return "info"
} }
// BuildTunUDPPacket serializes data as the payload of a UDP datagram wrapped in an IPv4 or IPv6 header and
// returns the raw bytes, ready to hand to Control.InjectTunPacket. UDP is used because it is the simplest
// protocol to assemble.
//
// The IP version is chosen from toAddr: an IPv6 destination requires an IPv6 source, anything else requires an
// IPv4 source. A family mismatch, or any serialization failure, panics.
func BuildTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) []byte {
	var (
		ipLayer  gopacket.SerializableLayer
		netLayer gopacket.NetworkLayer
	)

	switch {
	case toAddr.Is6():
		if !fromAddr.Is6() {
			panic("Cant send ipv6 to ipv4")
		}
		v6 := &layers.IPv6{
			Version:    6,
			NextHeader: layers.IPProtocolUDP,
			SrcIP:      fromAddr.Unmap().AsSlice(),
			DstIP:      toAddr.Unmap().AsSlice(),
		}
		ipLayer, netLayer = v6, v6

	default:
		if !fromAddr.Is4() {
			panic("Cant send ipv4 to ipv6")
		}
		v4 := &layers.IPv4{
			Version:  4,
			TTL:      64,
			Protocol: layers.IPProtocolUDP,
			SrcIP:    fromAddr.Unmap().AsSlice(),
			DstIP:    toAddr.Unmap().AsSlice(),
		}
		ipLayer, netLayer = v4, v4
	}

	udpLayer := &layers.UDP{
		SrcPort: layers.UDPPort(fromPort),
		DstPort: layers.UDPPort(toPort),
	}
	// The UDP checksum covers a pseudo-header built from the IP layer, so it must be attached first.
	if err := udpLayer.SetNetworkLayerForChecksum(netLayer); err != nil {
		panic(err)
	}

	out := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}
	if err := gopacket.SerializeLayers(out, opts, ipLayer, udpLayer, gopacket.Payload(data)); err != nil {
		panic(err)
	}

	return out.Bytes()
}
+3 -7
View File
@@ -18,14 +18,10 @@ import (
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain // retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
// before failing the assertion. // before failing the assertion.
// //
// IgnoreCurrent is necessary in the parallelized suite: other tests can // Intentionally NOT t.Parallel()'d: concurrent tests would have their own
// leave goroutines mid-shutdown when this one runs (Stop is async, the // goroutines running and trip the assertion.
// wg.Wait() drain is not blocking on test return). We're checking that
// *this* test's setup tears down cleanly, not that the whole suite is
// idle at this moment. Intentionally NOT t.Parallel()'d for the same
// reason — concurrent test goroutines would always show up.
func TestNoGoroutineLeaks(t *testing.T) { func TestNoGoroutineLeaks(t *testing.T) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) defer goleak.VerifyNone(t)
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+188 -54
View File
@@ -13,6 +13,7 @@ import (
"regexp" "regexp"
"sort" "sort"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -24,6 +25,19 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
// outNatKey is the (from, to) address pair that indexes R.outNat. It is a comparable struct so it can be used
// directly as a map key, avoiding the allocation cost of building a string-concat key on every lookup.
//
// getControl looks an entry up with the sending control's address as from and the packet's destination as to;
// on a hit the packet's From field is rewritten to the stored address.
type outNatKey struct {
// from is the source side of the pair, to is the destination side.
from, to netip.AddrPort
}
// fannedPacket pairs a UDP TX packet with its source control so the router can route it after popping from
// the fan-in channel. Without the source the router could not tell which control transmitted the packet once
// several controls share one channel.
type fannedPacket struct {
// from is the control whose UDP TX channel produced pkt.
from *nebula.Control
// pkt is the transmitted packet, not yet routed or released.
pkt *udp.Packet
}
type R struct { type R struct {
// Simple map of the ip:port registered on a control to the control // Simple map of the ip:port registered on a control to the control
// Basically a router, right? // Basically a router, right?
@@ -34,12 +48,28 @@ type R struct {
// A last used map, if an inbound packet hit the inNat map then // A last used map, if an inbound packet hit the inNat map then
// all return packets should use the same last used inbound address for the outbound sender // all return packets should use the same last used inbound address for the outbound sender
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver outNat map[outNatKey]netip.AddrPort
outNat map[string]netip.AddrPort
// A map of vpn ip to the nebula control it belongs to // A map of vpn ip to the nebula control it belongs to
vpnControls map[netip.Addr]*nebula.Control vpnControls map[netip.Addr]*nebula.Control
// Cached select infrastructure for RouteForAllUntilTxTun.
// The controls map is immutable after NewR so the cases are good for the test lifetime.
// We only rebuild if a different receiver is asked.
selRecvCtl *nebula.Control
selCases []reflect.SelectCase
selCtls []*nebula.Control
// Optional fan-in mode for hot-path benchmarks: one forwarder goroutine per control drains UDP TX into udpFanIn,
// so RouteForAllUntilTxTun can do a fixed 2-way native select instead of paying reflect.Select per call.
// Off by default (would otherwise interleave with tests that use GetFromUDP directly on the same control).
// Enabled by EnableFanIn.
udpFanIn chan fannedPacket
stopFanIn chan struct{}
fanInWG sync.WaitGroup
fanInMu sync.Mutex
fanInOn atomic.Bool
ignoreFlows []ignoreFlow ignoreFlows []ignoreFlow
flow []flowEntry flow []flowEntry
@@ -119,7 +149,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
controls: make(map[netip.AddrPort]*nebula.Control), controls: make(map[netip.AddrPort]*nebula.Control),
vpnControls: make(map[netip.Addr]*nebula.Control), vpnControls: make(map[netip.Addr]*nebula.Control),
inNat: make(map[netip.AddrPort]*nebula.Control), inNat: make(map[netip.AddrPort]*nebula.Control),
outNat: make(map[string]netip.AddrPort), outNat: make(map[outNatKey]netip.AddrPort),
flow: []flowEntry{}, flow: []flowEntry{},
ignoreFlows: []ignoreFlow{}, ignoreFlows: []ignoreFlow{},
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -153,8 +183,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-clockSource.C: case <-clockSource.C:
r.Lock()
r.renderHostmaps("clock tick") r.renderHostmaps("clock tick")
r.renderFlow() r.renderFlow()
r.Unlock()
} }
} }
}() }()
@@ -180,15 +212,21 @@ func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
// RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening. // RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening.
func (r *R) RenderFlow() { func (r *R) RenderFlow() {
r.cancelRender() r.cancelRender()
r.Lock()
defer r.Unlock()
r.renderFlow() r.renderFlow()
} }
// CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected // CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected
func (r *R) CancelFlowLogs() { func (r *R) CancelFlowLogs() {
r.cancelRender() r.cancelRender()
r.Lock()
r.flow = nil r.flow = nil
r.Unlock()
} }
// renderFlow writes the flow log to disk. Caller must hold r.Lock. renderFlow reads r.flow / r.additionalGraphs and
// the *packet pointers stashed inside, all of which are mutated under the same lock by routing paths.
func (r *R) renderFlow() { func (r *R) renderFlow() {
if r.flow == nil { if r.flow == nil {
return return
@@ -434,68 +472,157 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
panic("No control for udp tx " + a.String()) panic("No control for udp tx " + a.String())
} }
fp := r.unlockedInjectFlow(sender, c, p, false) fp := r.unlockedInjectFlow(sender, c, p, false)
c.InjectUDPPacket(p) c.InjectUDPPacket(p) // copies internally; original is ours to release
fp.WasReceived() fp.WasReceived()
r.Unlock() r.Unlock()
p.Release()
} }
} }
} }
// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on receivers tun // RouteForAllUntilTxTun will route for everyone and return when a packet is seen on the receiver's tun.
// If the router doesn't have the nebula controller for that address, we panic // If a control's UDP TX address can't be matched to a registered control, we panic.
//
// For allocation-sensitive callers (hot-path benchmarks, in particular relay
// benches with 3+ controls), call EnableFanIn() first.
func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
if r.fanInOn.Load() {
return r.routeFanIn(receiver)
}
return r.routeReflect(receiver)
}
// routeFanIn serves RouteForAllUntilTxTun while fan-in mode is enabled. It does a fixed two-way native select
// between the receiver's TUN TX channel and the shared fan-in channel, so no reflect.Select is needed.
// UDP packets popped from the fan-in channel are routed onward; the first packet seen on the receiver's TUN
// ends the loop and is returned to the caller.
func (r *R) routeFanIn(receiver *nebula.Control) []byte {
	tunTx := receiver.GetTunTxChan()
	for {
		select {
		case fanned := <-r.udpFanIn:
			r.routeUDP(fanned.from, fanned.pkt)

		case out := <-tunTx:
			r.Lock()
			if r.flow != nil {
				// The flow log keeps its own copy; the returned slice belongs to the caller.
				logged := &udp.Packet{Data: make([]byte, len(out))}
				copy(logged.Data, out)
				r.unlockedInjectFlow(receiver, receiver, logged, true)
			}
			r.Unlock()
			return out
		}
	}
}
// routeReflect is the default path for RouteForAllUntilTxTun, built on reflect.Select. It pays the boxing
// allocation on every call, but it leaves the controls' UDP TX channels untouched between calls, so tests that
// pull packets directly via GetFromUDP are not disturbed.
//
// Case 0 is the receiver's TUN TX channel; every other case is a control's UDP TX channel.
func (r *R) routeReflect(receiver *nebula.Control) []byte {
	cases, ctls := r.selectCasesFor(receiver)
	for {
		idx, val, _ := reflect.Select(cases)
		if idx != 0 {
			// A control transmitted over UDP: deliver it and keep waiting.
			r.routeUDP(ctls[idx], val.Interface().(*udp.Packet))
			continue
		}

		// The receiver emitted a packet on its TUN, we are done.
		out := val.Interface().([]byte)
		r.Lock()
		if r.flow != nil {
			logged := &udp.Packet{Data: make([]byte, len(out))}
			copy(logged.Data, out)
			r.unlockedInjectFlow(ctls[idx], ctls[idx], logged, true)
		}
		r.Unlock()
		return out
	}
}
// EnableFanIn switches RouteForAllUntilTxTun to the alloc-free fan-in path.
// One forwarder goroutine per registered control drains UDP TX into a shared channel that RouteForAllUntilTxTun selects
// on alongside the receiver's TUN TX channel.
//
// Calling it while fan-in is already enabled is a no-op. The forwarders are stopped by a cleanup registered on
// the router's test handle.
func (r *R) EnableFanIn() {
r.fanInMu.Lock()
defer r.fanInMu.Unlock()
if r.fanInOn.Load() {
return
}
// Fresh channels on every enable; the fan-in channel is buffered to 32 packets.
r.udpFanIn = make(chan fannedPacket, 32)
r.stopFanIn = make(chan struct{})
for _, c := range r.controls {
r.startFanInWorker(c)
}
// Flip the flag only after the channels exist and the workers are running, so RouteForAllUntilTxTun never
// takes the fan-in path before it is ready.
r.fanInOn.Store(true)
r.t.Cleanup(r.stopFanInWorkers)
}
// startFanInWorker launches the forwarder goroutine for c. The goroutine moves every packet from c's UDP TX
// channel onto r.udpFanIn, tagged with c as its source, until r.stopFanIn is closed. The goroutine is tracked
// by r.fanInWG.
func (r *R) startFanInWorker(c *nebula.Control) {
	r.fanInWG.Add(1)
	source := c.GetUDPTxChan()

	go func() {
		defer r.fanInWG.Done()
		for {
			select {
			case <-r.stopFanIn:
				return

			case pkt := <-source:
				// The fan-in channel may be full; keep watching for stop so shutdown never blocks on
				// the send, and release the packet we are holding if we give up on it.
				select {
				case r.udpFanIn <- fannedPacket{from: c, pkt: pkt}:
				case <-r.stopFanIn:
					pkt.Release()
					return
				}
			}
		}
	}()
}
// stopFanInWorkers turns fan-in mode off, tells the forwarder goroutines to exit and blocks until they have.
// It does nothing when fan-in mode was not on.
func (r *R) stopFanInWorkers() {
	r.fanInMu.Lock()
	running := r.fanInOn.Swap(false)
	r.fanInMu.Unlock()

	if running {
		close(r.stopFanIn)
		r.fanInWG.Wait()
	}
}
// routeUDP delivers a UDP TX packet emitted by from to the control that owns p.To. The whole operation runs
// under the router lock: the destination is resolved through getControl, the hop is recorded in the flow log,
// the packet is injected into the destination, and finally the source packet is released.
// It panics when no control is registered for the destination.
func (r *R) routeUDP(from *nebula.Control, p *udp.Packet) {
	r.Lock()
	defer r.Unlock()

	srcAddr := from.GetUDPAddr()
	dest := r.getControl(srcAddr, p.To, p)
	if dest == nil {
		panic(fmt.Sprintf("No control for udp tx %s", p.To))
	}

	entry := r.unlockedInjectFlow(from, dest, p, false)
	dest.InjectUDPPacket(p) // copies internally; original is ours to release
	entry.WasReceived()
	p.Release()
}
// selectCasesFor returns the SelectCase array used by routeReflect: one slot for the receiver's TUN TX channel followed
// by one per control's UDP TX channel. Cached for the test lifetime, only rebuilt if the receiver changes.
func (r *R) selectCasesFor(receiver *nebula.Control) ([]reflect.SelectCase, []*nebula.Control) {
r.Lock()
defer r.Unlock()
if r.selRecvCtl == receiver && r.selCases != nil {
return r.selCases, r.selCtls
}
sc := make([]reflect.SelectCase, len(r.controls)+1) sc := make([]reflect.SelectCase, len(r.controls)+1)
cm := make([]*nebula.Control, len(r.controls)+1) cm := make([]*nebula.Control, len(r.controls)+1)
sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan())}
i := 0 cm[0] = receiver
sc[i] = reflect.SelectCase{ i := 1
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(receiver.GetTunTxChan()),
Send: reflect.Value{},
}
cm[i] = receiver
i++
for _, c := range r.controls { for _, c := range r.controls {
sc[i] = reflect.SelectCase{ sc[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan())}
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
cm[i] = c cm[i] = c
i++ i++
} }
r.selRecvCtl = receiver
for { r.selCases = sc
x, rx, _ := reflect.Select(sc) r.selCtls = cm
r.Lock() return sc, cm
if x == 0 {
// we are the tun tx, we can exit
p := rx.Interface().([]byte)
np := udp.Packet{Data: make([]byte, len(p))}
copy(np.Data, p)
r.unlockedInjectFlow(cm[x], cm[x], &np, true)
r.Unlock()
return p
} else {
// we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet)
a := cm[x].GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil {
r.Unlock()
panic(fmt.Sprintf("No control for udp tx %s", p.To))
}
fp := r.unlockedInjectFlow(cm[x], c, p, false)
c.InjectUDPPacket(p)
fp.WasReceived()
}
r.Unlock()
}
} }
// RouteExitFunc will call the whatDo func with each udp packet from sender. // RouteExitFunc will call the whatDo func with each udp packet from sender.
@@ -522,6 +649,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
switch e { switch e {
case ExitNow: case ExitNow:
r.Unlock() r.Unlock()
p.Release()
return return
case RouteAndExit: case RouteAndExit:
@@ -529,6 +657,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
receiver.InjectUDPPacket(p) receiver.InjectUDPPacket(p)
fp.WasReceived() fp.WasReceived()
r.Unlock() r.Unlock()
p.Release()
return return
case KeepRouting: case KeepRouting:
@@ -541,6 +670,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
} }
r.Unlock() r.Unlock()
p.Release()
} }
} }
@@ -641,6 +771,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
switch e { switch e {
case ExitNow: case ExitNow:
r.Unlock() r.Unlock()
p.Release()
return return
case RouteAndExit: case RouteAndExit:
@@ -648,6 +779,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
receiver.InjectUDPPacket(p) receiver.InjectUDPPacket(p)
fp.WasReceived() fp.WasReceived()
r.Unlock() r.Unlock()
p.Release()
return return
case KeepRouting: case KeepRouting:
@@ -659,6 +791,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
} }
r.Unlock() r.Unlock()
p.Release()
} }
} }
@@ -702,19 +835,20 @@ func (r *R) FlushAll() {
} }
receiver.InjectUDPPacket(p) receiver.InjectUDPPacket(p)
r.Unlock() r.Unlock()
p.Release()
} }
} }
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock // This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { if newAddr, ok := r.outNat[outNatKey{from: fromAddr, to: toAddr}]; ok {
p.From = newAddr p.From = newAddr
} }
c, ok := r.inNat[toAddr] c, ok := r.inNat[toAddr]
if ok { if ok {
r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr r.outNat[outNatKey{from: c.GetUDPAddr(), to: fromAddr}] = toAddr
return c return c
} }
+125
View File
@@ -0,0 +1,125 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"net"
"strings"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
// TestSSHDLifecycle exercises the in-process sshd through several config reloads and a Control.Stop.
// It asserts that the sshd is listening after start and after each of three reloads, and that it is no longer
// listening once the control has been stopped.
func TestSSHDLifecycle(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(
cert.Version1, cert.Curve_CURVE25519,
time.Now(), time.Now().Add(10*time.Minute),
nil, nil, []string{},
)

hostKeyPEM := generateSSHHostKey(t)
clientSigner, clientAuthKey := generateSSHClientKey(t)
sshdAddr := allocLoopbackPort(t)

// Enable the sshd on the reserved loopback address with a single authorized user, "tester".
overrides := m{
"sshd": m{
"enabled": true,
"listen": sshdAddr,
"host_key": hostKeyPEM,
"authorized_users": []m{{
"user": "tester",
"keys": []string{clientAuthKey},
}},
},
}

control, _, _, _ := newSimpleServer(cert.Version1, ca, caKey, "sshd-test", "10.222.0.1/24", overrides)
control.Start()
// NOTE(review): Stop is also called explicitly at the end of the test, so on the success path it runs
// twice. Confirm Control.Stop tolerates a second call.
t.Cleanup(func() { control.Stop() })

// sshd binds in a goroutine after Start returns; wait for it.
require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
"sshd never started listening")

// Issue "reload" over ssh three times; after each one the sshd must be reachable on the same address.
for i := 1; i <= 3; i++ {
out := sshExecReload(t, sshdAddr, clientSigner)
assert.Contains(t, out, "Reloading config", "reload cycle %d", i)
require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
"sshd not listening after reload cycle %d", i)
}

// Stopping the control must also take the sshd listener down.
control.Stop()
require.Eventually(t, func() bool { return !canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
"sshd still listening after Control.Stop")
}
// canDial reports whether a TCP connection to addr can be established within 100ms. A successful connection
// is closed again immediately; the result only says whether something is accepting connections.
func canDial(addr string) bool {
	conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
	if err == nil {
		_ = conn.Close()
	}
	return err == nil
}
// allocLoopbackPort asks the OS for a free TCP port on 127.0.0.1 by listening on port 0, closes the listener
// again, and returns the host:port it was given. Nothing holds the port once this returns, so there is a small
// race before the sshd binds it; in practice the OS keeps the port available long enough for the test.
func allocLoopbackPort(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	reserved := listener.Addr().String()
	err = listener.Close()
	require.NoError(t, err)

	return reserved
}
// generateSSHHostKey creates a fresh ed25519 private key and returns it PEM-encoded, in the form produced by
// ssh.MarshalPrivateKey with the comment "nebula-e2e-host". Any failure fails the test immediately.
func generateSSHHostKey(t *testing.T) string {
	t.Helper()

	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	pemBlock, err := ssh.MarshalPrivateKey(privateKey, "nebula-e2e-host")
	require.NoError(t, err)

	encoded := pem.EncodeToMemory(pemBlock)
	return string(encoded)
}
// generateSSHClientKey creates a fresh ed25519 key pair for an ssh client. It returns a signer backed by the
// private key and the public key as an authorized_keys line with surrounding whitespace trimmed.
// Any failure fails the test immediately.
func generateSSHClientKey(t *testing.T) (ssh.Signer, string) {
	t.Helper()

	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	signer, err := ssh.NewSignerFromKey(privateKey)
	require.NoError(t, err)

	authorizedLine := string(ssh.MarshalAuthorizedKey(signer.PublicKey()))
	return signer, strings.TrimSpace(authorizedLine)
}
// sshExecReload connects to the sshd at addr as user "tester", authenticating with signer, runs the "reload"
// command and returns whatever output was collected. The host key is not verified. Failing to connect or to
// open a session fails the test immediately.
func sshExecReload(t *testing.T, addr string, signer ssh.Signer) string {
	t.Helper()

	client, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
		User:            "tester",
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         2 * time.Second,
	})
	require.NoError(t, err)
	defer client.Close()

	session, err := client.NewSession()
	require.NoError(t, err)
	defer session.Close()

	// The error is dropped on purpose: reload tears the channel down before sending exit-status, so
	// Output reports an error on the channel close. The output buffer still has whatever the reload
	// callback wrote before that.
	collected, _ := session.Output("reload")
	return string(collected)
}
+2 -2
View File
@@ -355,14 +355,14 @@ func TestCrossStackRelaysWork(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?") t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl) p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
+30
View File
@@ -138,6 +138,14 @@ listen:
# max, net.core.rmem_max and net.core.wmem_max # max, net.core.rmem_max and net.core.wmem_max
#read_buffer: 10485760 #read_buffer: 10485760
#write_buffer: 10485760 #write_buffer: 10485760
# On Windows only
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to UDP at the listener port.
# WFP sits below Windows Defender Firewall, so this lets peer handshakes reach Nebula's outside socket regardless
# of WDF's inbound rules.
# Default true; set to false to leave WDF in charge of inbound decisions on the listener port. Not reloadable.
#windows_bypass_wdf: true
# By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection # By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection
# in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running # in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running
# on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes. # on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes.
@@ -163,17 +171,21 @@ listen:
punchy: punchy:
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings # Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
# This setting is reloadable.
punch: true punch: true
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails # respond means that a node you are trying to reach will connect back out to you if your hole punching fails
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT # this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
# Default is false # Default is false
# This setting is reloadable.
#respond: true #respond: true
# delays a punch response for misbehaving NATs, default is 1 second. # delays a punch response for misbehaving NATs, default is 1 second.
# This setting is reloadable.
#delay: 1s #delay: 1s
# set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect. # set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect.
# This setting is reloadable.
#respond_delay: 5s #respond_delay: 5s
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes # Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
@@ -282,6 +294,24 @@ tun:
# metric: 100 # metric: 100
# install: true # install: true
# On Windows only, sets the network category of the nebula interface. Without this, Windows often
# leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more
# restrictive than you usually want for an overlay between trusted peers. Valid values:
# private - treat the nebula network as a private/trusted network (default)
# public - treat it as a public/untrusted network
# domain - treat it as a domain-authenticated network
# unset - do not change the category; keep whatever Windows assigned
# Not reloadable.
#network_category: private
# On Windows only
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to the nebula adapter LUID.
# WFP sits below Windows Defender Firewall, so this lets inbound traffic through regardless of WDF rules.
# Filters are auto-removed when the adapter goes away.
# See listen.windows_bypass_wdf for the matching control over inbound to nebula's outside UDP listener.
# Default true; set to false to leave WDF in charge of inbound decisions on the nebula interface. Not reloadable.
#windows_bypass_wdf: true
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
# in nebula configuration files. Default false, not reloadable. # in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false #use_system_route_table: false
+34 -17
View File
@@ -59,7 +59,8 @@ type Firewall struct {
// assignedNetworks is a list of vpn networks assigned to us in the certificate. // assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix assignedNetworks []netip.Prefix
hasUnsafeNetworks bool // unsafeNetworks is the list of unsafe networks issued to us in the certificate
unsafeNetworks []netip.Prefix
rules string rules string
rulesVersion uint16 rulesVersion uint16
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
assignedNetworks = append(assignedNetworks, network) assignedNetworks = append(assignedNetworks, network)
} }
hasUnsafeNetworks := false unsafeNetworks := c.UnsafeNetworks()
for _, n := range c.UnsafeNetworks() { for _, n := range unsafeNetworks {
routableNetworks.Insert(n) routableNetworks.Insert(n)
hasUnsafeNetworks = true
} }
return &Firewall{ return &Firewall{
@@ -176,7 +176,7 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
routableNetworks: routableNetworks, routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks, assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks, unsafeNetworks: unsafeNetworks,
l: l, l: l,
incomingMetrics: firewallMetrics{ incomingMetrics: firewallMetrics{
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
} }
if localCidr == "" { if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
flc.Any = true flc.Any = true
return nil return nil
} }
@@ -1055,7 +1055,6 @@ func (r *rule) sanity() error {
} }
func parsePort(s string) (int32, int32, error) { func parsePort(s string) (int32, int32, error) {
var err error
const notAPort int32 = -2 const notAPort int32 = -2
if s == "any" { if s == "any" {
return firewall.PortAny, firewall.PortAny, nil return firewall.PortAny, firewall.PortAny, nil
@@ -1064,11 +1063,11 @@ func parsePort(s string) (int32, int32, error) {
return firewall.PortFragment, firewall.PortFragment, nil return firewall.PortFragment, firewall.PortFragment, nil
} }
if !strings.Contains(s, `-`) { if !strings.Contains(s, `-`) {
rPort, err := strconv.Atoi(s) rPort, err := parsePortValue("", s)
if err != nil { if err != nil {
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) return notAPort, notAPort, err
} }
return int32(rPort), int32(rPort), nil return rPort, rPort, nil
} }
sPorts := strings.SplitN(s, `-`, 2) sPorts := strings.SplitN(s, `-`, 2)
@@ -1079,22 +1078,40 @@ func parsePort(s string) (int32, int32, error) {
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
} }
rStartPort, err := strconv.Atoi(sPorts[0]) startPort, err := parsePortValue("beginning range ", sPorts[0])
if err != nil { if err != nil {
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) return notAPort, notAPort, err
} }
rEndPort, err := strconv.Atoi(sPorts[1]) endPort, err := parsePortValue("ending range ", sPorts[1])
if err != nil { if err != nil {
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) return notAPort, notAPort, err
} }
startPort := int32(rStartPort)
endPort := int32(rEndPort)
if startPort == firewall.PortAny { if startPort == firewall.PortAny {
endPort = firewall.PortAny endPort = firewall.PortAny
} }
return startPort, endPort, nil return startPort, endPort, nil
} }
// parsePortValue converts a single port token into an int32.
//
// The token must be a plain base-10 number in [0, 65535]. Parsing is done
// with strconv.ParseUint at bitSize 16, which refuses sign prefixes, values
// above 65535, and any non-digit byte, so the widening to int32 below can
// never truncate into firewall.PortAny (0) or firewall.PortFragment (-1).
//
// prefix is placed in front of whichever error message is produced. The
// single-port caller passes "" while the range callers pass
// "beginning range " or "ending range ", which keeps the historical error
// strings intact. On failure the returned port is always 0.
func parsePortValue(prefix, s string) (int32, error) {
	port, parseErr := strconv.ParseUint(s, 10, 16)

	switch {
	case parseErr == nil:
		return int32(port), nil
	case errors.Is(parseErr, strconv.ErrRange):
		// Syntactically a number, but it does not fit in 16 bits.
		return 0, fmt.Errorf("%sout of range [0,65535]; `%s`", prefix, s)
	default:
		// Everything else (empty input, signs, hex, whitespace, ...) is a syntax error.
		return 0, fmt.Errorf("%swas not a number; `%s`", prefix, s)
	}
}
+69
View File
@@ -1029,6 +1029,75 @@ func Test_parsePort(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
// Test_parsePort_invalid covers inputs that parsePort must reject.
//
// The named bug is that int32(strconv.Atoi("4294967296")) truncates to
// 0 == firewall.PortAny, silently turning a typo into a match-all-ports rule.
// The remaining cases are representative syntax and range probes. Each case
// asserts both that an error is returned and that its message contains the
// expected fragment, so the historical error strings stay stable.
func Test_parsePort_invalid(t *testing.T) {
	tests := []struct {
		name            string
		input           string
		wantErrContains string
	}{
		// Numeric overflow (the named bug + boundary).
		{"named bug: 2^32 truncates to PortAny", "4294967296", "out of range"},
		{"just above max real port", "65536", "out of range"},

		// Negatives route through the range branch and hit the empty-half
		// guard; included as defense in depth so a future refactor cannot
		// accidentally reach the int32 cast.
		{"negative", "-1", "could not be parsed"},

		// Syntax probes.
		{"NUL between digits", "4\x002", "was not a number"},
		{"hex notation", "0x10", "was not a number"},
		{"scientific notation", "1e3", "was not a number"},
		{"leading whitespace", " 42", "was not a number"},
		// Spelled with \u escapes (U+FF14 U+FF12) so the input really is
		// fullwidth and cannot be normalized to ASCII "42", which is a valid
		// port and would make this case fail.
		{"fullwidth digits", "\uff14\uff12", "was not a number"},

		// Range branch.
		{"range upper out of range", "1-65536", "ending range out of range"},
		{"range lower out of range", "65536-65537", "beginning range out of range"},
		{"range with negative upper", "1--1", "ending range was not a number"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, _, err := parsePort(tc.input)
			require.Error(t, err, "input %q must error", tc.input)
			require.ErrorContains(t, err, tc.wantErrContains)
		})
	}
}
// Test_parsePort_valid_boundaries pins the accepted edge values (0, 1 and
// 65535, alone and as range bounds) so that later refactors of parsePort
// cannot shift the boundaries unnoticed.
func Test_parsePort_valid_boundaries(t *testing.T) {
	type boundaryCase struct {
		name      string
		input     string
		wantStart int32
		wantEnd   int32
	}

	cases := []boundaryCase{
		{name: "zero is PortAny", input: "0", wantStart: 0, wantEnd: 0},
		{name: "min real port", input: "1", wantStart: 1, wantEnd: 1},
		{name: "max real port", input: "65535", wantStart: 65535, wantEnd: 65535},
		// A range whose start equals PortAny has its end forced to PortAny too.
		{name: "range zero to max forces end to zero", input: "0-65535", wantStart: 0, wantEnd: 0},
		{name: "range max to max", input: "65535-65535", wantStart: 65535, wantEnd: 65535},
		{name: "range one to max", input: "1-65535", wantStart: 1, wantEnd: 65535},
		// Presumably the range branch trims spaces around each bound; this
		// case locks that behavior in.
		{name: "range with whitespace inside", input: " 1 - 2 ", wantStart: 1, wantEnd: 2},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			gotStart, gotEnd, err := parsePort(c.input)
			require.NoError(t, err)
			assert.Equal(t, c.wantStart, gotStart, "start port")
			assert.Equal(t, c.wantEnd, gotEnd, "end port")
		})
	}
}
func TestNewFirewallFromConfig(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Test a bad rule definition // Test a bad rule definition
+5 -5
View File
@@ -9,7 +9,7 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0 github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.26.0 github.com/gaissmai/bart v0.27.1
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4 github.com/kardianos/service v1.2.4
@@ -24,12 +24,12 @@ require (
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
go.uber.org/goleak v1.3.0 go.uber.org/goleak v1.3.0
go.yaml.in/yaml/v3 v3.0.4 go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.50.0 golang.org/x/crypto v0.51.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.52.0 golang.org/x/net v0.54.0
golang.org/x/sync v0.20.0 golang.org/x/sync v0.20.0
golang.org/x/sys v0.43.0 golang.org/x/sys v0.44.0
golang.org/x/term v0.42.0 golang.org/x/term v0.43.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.6.1 golang.zx2c4.com/wireguard/windows v0.6.1
+10 -10
View File
@@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= github.com/gaissmai/bart v0.27.1 h1:FysPzqETMJa8q9rNkLW5peT1hq25nLOz8ksHbSVoiAk=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/gaissmai/bart v0.27.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+2
View File
@@ -31,6 +31,7 @@ type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
type Result struct { type Result struct {
EKey *noise.CipherState EKey *noise.CipherState
DKey *noise.CipherState DKey *noise.CipherState
Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in
MyCert cert.Certificate MyCert cert.Certificate
RemoteCert *cert.CachedCertificate RemoteCert *cert.CachedCertificate
RemoteIndex uint32 RemoteIndex uint32
@@ -105,6 +106,7 @@ func NewMachine(
myVersion: version, myVersion: version,
result: &Result{ result: &Result{
Initiator: initiator, Initiator: initiator,
Cipher: cred.cipherSuite,
}, },
}, nil }, nil
} }
+2 -146
View File
@@ -23,7 +23,6 @@ const (
DefaultHandshakeTryInterval = time.Millisecond * 100 DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 10 DefaultHandshakeRetries = 10
DefaultHandshakeTriggerBuffer = 64 DefaultHandshakeTriggerBuffer = 64
DefaultUseRelays = true
// maxCachedPackets is how many unsent packets we'll buffer per pending // maxCachedPackets is how many unsent packets we'll buffer per pending
// handshake before dropping further ones. // handshake before dropping further ones.
@@ -43,7 +42,6 @@ var (
tryInterval: DefaultHandshakeTryInterval, tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries, retries: DefaultHandshakeRetries,
triggerBuffer: DefaultHandshakeTriggerBuffer, triggerBuffer: DefaultHandshakeTriggerBuffer,
useRelays: DefaultUseRelays,
} }
) )
@@ -51,7 +49,6 @@ type HandshakeConfig struct {
tryInterval time.Duration tryInterval time.Duration
retries int64 retries int64
triggerBuffer int triggerBuffer int
useRelays bool
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
} }
@@ -86,6 +83,7 @@ type HandshakeHostInfo struct {
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake? initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
lastRelays []netip.Addr // Relays we attempted to use during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo hostinfo *HostInfo
@@ -220,7 +218,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
fields := []any{ fields := []any{
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
"initiatorIndex", hh.hostinfo.localIndexId, "initiatorIndex", hh.hostinfo.localIndexId,
"remoteIndex", hh.hostinfo.remoteIndexId,
"durationNs", time.Since(hh.startTime).Nanoseconds(), "durationNs", time.Since(hh.startTime).Nanoseconds(),
} }
// hh.machine can be nil here if buildStage0Packet never succeeded // hh.machine can be nil here if buildStage0Packet never succeeded
@@ -326,146 +323,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
) )
} }
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { hm.f.relayManager.StartRelays(hm.f, vpnIp, hh, stage0)
hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to
if relay == vpnIp {
continue
}
// Don't relay to myself
if hm.f.myVpnAddrsTable.Contains(relay) {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
hm.f.Handshake(relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through
existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
if !ok {
// No relays exist or requested yet.
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
}
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", idx,
"relay", relay,
)
}
}
continue
}
switch existingRelay.State {
case Established:
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", existingRelay.LocalIndex,
"relay", relay,
)
}
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hostinfo.logger(hm.l).Error("Relay unexpected state",
"vpnIp", vpnIp,
"state", existingRelay.State,
"relay", relay,
)
}
}
}
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered { if !lighthouseTriggered {
@@ -607,7 +465,6 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs, "collision", existingRemoteIndex.vpnAddrs,
) )
} }
@@ -630,7 +487,6 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs, "collision", existingRemoteIndex.vpnAddrs,
) )
} }
+14
View File
@@ -174,6 +174,10 @@ func (h *H) SubTypeName() string {
return SubTypeName(h.Type, h.Subtype) return SubTypeName(h.Type, h.Subtype)
} }
// IsValidSubType reports whether this header's Type/Subtype pair passes the
// package-level IsValidSubType check.
func (h *H) IsValidSubType() bool {
return IsValidSubType(h.Type, h.Subtype)
}
// SubTypeName will transform a nebula message sub type into a human string // SubTypeName will transform a nebula message sub type into a human string
func SubTypeName(t MessageType, s MessageSubType) string { func SubTypeName(t MessageType, s MessageSubType) string {
if n, ok := subTypeMap[t]; ok { if n, ok := subTypeMap[t]; ok {
@@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string {
return "unknown" return "unknown"
} }
// IsValidSubType reports whether s is a known sub type for message type t,
// i.e. whether subTypeMap has a table for t and that table has an entry for s.
func IsValidSubType(t MessageType, s MessageSubType) bool {
	names, found := subTypeMap[t]
	if !found {
		// No sub type table is registered for this message type.
		return false
	}

	_, found = (*names)[s]
	return found
}
// NewHeader turns bytes into a header // NewHeader turns bytes into a header
func NewHeader(b []byte) (*H, error) { func NewHeader(b []byte) (*H, error) {
h := new(H) h := new(H)
-1
View File
@@ -391,7 +391,6 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
"error", err, "error", err,
"udpAddr", remote, "udpAddr", remote,
"counter", c, "counter", c,
"attemptedCounter", c,
) )
return return
} }
+27 -8
View File
@@ -7,6 +7,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -14,6 +15,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
} }
func (f *Interface) reloadFirewall(c *config.C) { func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too cs := f.pki.getCertState()
if c.HasChanged("firewall") == false { curCert := cs.getCertificate(cert.Version2)
if curCert == nil {
curCert = cs.getCertificate(cert.Version1)
}
// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
// Check to see if that set has changed, and if so, rebuild the firewall.
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)
if !c.HasChanged("firewall") && !certUnsafeChanged {
f.l.Debug("No firewall config change detected") f.l.Debug("No firewall config change detected")
return return
} }
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) fw, err := NewFirewallFromConfig(f.l, cs, c)
if err != nil { if err != nil {
f.l.Error("Error while creating firewall during reload", "error", err) f.l.Error("Error while creating firewall during reload", "error", err)
return return
@@ -491,11 +502,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil) certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
for { emit := func() {
select {
case <-ctx.Done():
return
case <-ticker.C:
f.firewall.EmitStats() f.firewall.EmitStats()
f.handshakeManager.EmitStats() f.handshakeManager.EmitStats()
udpStats() udpStats()
@@ -512,6 +519,18 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certMaxVersion.Update(int64(certState.v1Cert.Version())) certMaxVersion.Update(int64(certState.v1Cert.Version()))
} }
} }
// Prime gauges so a Prometheus scrape that lands before the first tick
// sees real values instead of the zero defaults (issue #907).
emit()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
emit()
}
} }
} }
+73
View File
@@ -0,0 +1,73 @@
//go:build linux || darwin
package nebula
import (
"context"
"net/netip"
"testing"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_emitStats_primesGauges is the regression test for issue #907, where a
// Prometheus scrape arriving ahead of the first ticker fire read 0 for the
// cert gauges. emitStats is expected to publish once before it starts waiting
// on the ticker, so the ttl gauge must be zero beforehand and positive after.
func Test_emitStats_primesGauges(t *testing.T) {
	defer metrics.DefaultRegistry.UnregisterAll()

	logger := test.NewLogger()

	hm := newHostMap(logger)
	ranges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
	hm.preferredRanges.Store(&ranges)

	// A v1-only cert state whose certificate expires one hour from now.
	expiry := time.Now().Add(time.Hour)
	certState := &CertState{
		initiatingVersion: cert.Version1,
		privateKey:        []byte{},
		v1Cert:            &dummyCert{version: cert.Version1, notAfter: expiry},
		v1Credential:      nil,
	}

	lighthouse := newTestLighthouse()
	iface := &Interface{
		hostMap:          hm,
		inside:           &overlaytest.NoopTun{},
		outside:          &udp.NoopConn{},
		firewall:         &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
		lightHouse:       lighthouse,
		pki:              &PKI{},
		handshakeManager: NewHandshakeManager(logger, hm, lighthouse, &udp.NoopConn{}, defaultHandshakeConfig),
		l:                logger,
		// On linux, udp.NewUDPStatsEmitter reads writers[0] and asserts it to
		// *udp.StdConn. The zero value is enough: getMemInfo finds a nil
		// rawConn, reports an error, and the emitter degrades to a no-op.
		writers: []udp.Conn{&udp.StdConn{}},
	}
	iface.pki.cs.Store(certState)

	ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
	require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs")

	// Cancel the context up front so emitStats returns right after priming
	// the gauges, without ever reading from ticker.C. The one hour interval
	// is only an extra safeguard; the ticker is not expected to fire.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	iface.emitStats(ctx, time.Hour)

	primedTTL := ttlGauge.Value()
	assert.Positive(t, primedTTL, "ttl gauge should be primed by emitStats before its first tick")
	assert.LessOrEqual(t, primedTTL, int64(3600))

	initiatingGauge := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
	assert.Equal(t, int64(cert.Version1), initiatingGauge.Value())
	maxGauge := metrics.GetOrRegisterGauge("certificate.max_version", nil)
	assert.Equal(t, int64(cert.Version1), maxGauge.Value())
}
+120
View File
@@ -0,0 +1,120 @@
package nebula
import (
"net/netip"
"testing"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall
// rebuilds the firewall when only the certificate's UnsafeNetworks have changed,
// even if the firewall section of the YAML has not.
//
// Flow: build a firewall from cert c1, swap the stored cert state for c2 (one
// extra unsafe network), reload the identical YAML, then check that the
// Interface holds a new firewall reflecting c2.
func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) {
l := test.NewLogger()
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
// dummyCert avoids dragging the real signing pipeline into a unit test.
c1 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: initialUnsafe,
}
pki := &PKI{}
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
// Allow-all rules in both directions; the rules themselves are not under test.
rawYAML := `firewall:
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
`
cfg := config.NewC(l)
require.NoError(t, cfg.LoadString(rawYAML))
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
require.NoError(t, err)
// Precondition: the initial firewall carries the first cert's unsafe networks.
require.Equal(t, initialUnsafe, fw.unsafeNetworks)
// Only the fields set here are populated on the Interface under test.
f := &Interface{
pki: pki,
firewall: fw,
l: l,
}
// Swap the cert with a different UnsafeNetworks set.
newUnsafe := []netip.Prefix{
netip.MustParsePrefix("198.51.100.0/24"),
netip.MustParsePrefix("203.0.113.0/24"),
}
c2 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: newUnsafe,
}
pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2})
// Reload with the same YAML so HasChanged("firewall") reports false.
require.NoError(t, cfg.ReloadConfigString(rawYAML))
require.False(t, cfg.HasChanged("firewall"))
f.reloadFirewall(cfg)
// The firewall must be a new instance built from c2, and an address inside
// the newly added 203.0.113.0/24 prefix must be in its routable networks.
assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced")
assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks)
assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5")))
}
// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when
// neither the firewall config nor the cert's UnsafeNetworks have changed.
//
// The same YAML is reloaded and the stored cert state is left untouched, so
// the Interface must keep the exact firewall instance it started with.
func TestReloadFirewall_NoChange(t *testing.T) {
l := test.NewLogger()
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
c1 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: unsafe,
}
pki := &PKI{}
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
// Allow-all rules in both directions; the rules themselves are not under test.
rawYAML := `firewall:
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
`
cfg := config.NewC(l)
require.NoError(t, cfg.LoadString(rawYAML))
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
require.NoError(t, err)
// Only the fields set here are populated on the Interface under test.
f := &Interface{
pki: pki,
firewall: fw,
l: l,
}
// Reload identical YAML; the cert state stored in pki is unchanged as well.
require.NoError(t, cfg.ReloadConfigString(rawYAML))
f.reloadFirewall(cfg)
// Pointer identity proves the firewall was not rebuilt.
assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
}
+25 -48
View File
@@ -15,7 +15,6 @@ import (
"time" "time"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@@ -35,7 +34,6 @@ type LightHouse struct {
myVpnNetworks []netip.Prefix myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Lite myVpnNetworksTable *bart.Lite
punchConn udp.Conn
punchy *Punchy punchy *Punchy
// Local cache of answers from light houses // Local cache of answers from light houses
@@ -76,7 +74,6 @@ type LightHouse struct {
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
metrics *MessageMetrics metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *slog.Logger l *slog.Logger
} }
@@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
myVpnNetworksTable: cs.myVpnNetworksTable, myVpnNetworksTable: cs.myVpnNetworksTable,
addrMap: make(map[netip.Addr]*RemoteList), addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort, nebulaPort: nebulaPort,
punchConn: pc,
punchy: p, punchy: p,
updateTrigger: make(chan struct{}, 1), updateTrigger: make(chan struct{}, 1),
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
@@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
if c.GetBool("stats.lighthouse_metrics", false) { if c.GetBool("stats.lighthouse_metrics", false) {
h.metrics = newLighthouseMetrics() h.metrics = newLighthouseMetrics()
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
h.metricHolepunchTx = metrics.NilCounter{}
} }
err := h.reload(c, true) err := h.reload(c, true)
@@ -279,16 +272,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
// Clean up. Entries still in the static_host_map will be re-built. // Clean up. Entries still in the static_host_map will be re-built.
// Entries no longer present must have their (possible) background DNS goroutines stopped. ourselves := lh.myVpnNetworks[0].Addr()
if existingStaticList := lh.staticList.Load(); existingStaticList != nil { oldStaticList := lh.staticList.Load()
if oldStaticList != nil {
lh.RLock() lh.RLock()
for staticVpnAddr := range *existingStaticList { for staticVpnAddr := range *oldStaticList {
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
am.hr.Cancel() am.ResetForOwner(ourselves)
} }
} }
lh.RUnlock() lh.RUnlock()
} }
// Build a new list based on current config. // Build a new list based on current config.
staticList := make(map[netip.Addr]struct{}) staticList := make(map[netip.Addr]struct{})
err := lh.loadStaticMap(c, staticList) err := lh.loadStaticMap(c, staticList)
@@ -296,6 +291,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return err return err
} }
// For entries removed from static_host_map, stop the DNS goroutine and drop the cached addrs.
// All addrs must come from the lighthouses now that it's no longer a static host.
if oldStaticList != nil {
lh.RLock()
for staticVpnAddr := range *oldStaticList {
if _, stillStatic := staticList[staticVpnAddr]; stillStatic {
continue
}
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
am.ClearHostnameResults()
}
}
lh.RUnlock()
}
lh.staticList.Store(&staticList) lh.staticList.Store(&staticList)
if !initial { if !initial {
if c.HasChanged("static_host_map") { if c.HasChanged("static_host_map") {
@@ -1406,58 +1416,25 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
return return
} }
empty := []byte{0}
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
if !vpnPeer.IsValid() {
return
}
go func() {
time.Sleep(lhh.lh.punchy.GetDelay())
lhh.lh.metricHolepunchTx.Inc(1)
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}()
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Punching",
"vpnPeer", vpnPeer,
"logVpnAddr", logVpnAddr,
)
}
}
remoteAllowList := lhh.lh.GetRemoteAllowList() remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts { for _, a := range n.Details.V4AddrPorts {
b := protoV4AddrPortToNetAddrPort(a) b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr) lhh.lh.punchy.Schedule(b, detailsVpnAddr)
} }
} }
for _, a := range n.Details.V6AddrPorts { for _, a := range n.Details.V6AddrPorts {
b := protoV6AddrPortToNetAddrPort(a) b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr) lhh.lh.punchy.Schedule(b, detailsVpnAddr)
} }
} }
// This sends a nebula test packet to the host trying to contact us. In the case // This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish // of a double nat or other difficult scenario, this may help establish
// a tunnel. // a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled.
if lhh.lh.punchy.GetRespond() { lhh.lh.punchy.ScheduleRespond(detailsVpnAddr)
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Sending a nebula test packet",
"vpnAddr", detailsVpnAddr,
)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
} }
func protoAddrToNetAddr(addr *Addr) netip.Addr { func protoAddrToNetAddr(addr *Addr) netip.Addr {
+126
View File
@@ -303,6 +303,132 @@ func TestLighthouse_reload(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
// TestLighthouse_reloadStaticHostMap walks through a series of static_host_map reloads and
// checks that each one replaces the configured remotes rather than appending to them.
// See issue #718.
func TestLighthouse_reloadStaticHostMap(t *testing.T) {
	l := test.NewLogger()
	c := config.NewC(l)
	c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
	c.Settings["listen"] = map[string]any{"port": 4242}
	c.Settings["static_host_map"] = map[string]any{
		"10.128.0.2": []any{"1.1.1.1:4242"},
	}

	ownNet := netip.MustParsePrefix("10.128.0.1/24")
	netTable := new(bart.Lite)
	netTable.Insert(ownNet)
	certState := &CertState{
		myVpnNetworks:      []netip.Prefix{ownNet},
		myVpnNetworksTable: netTable,
	}
	lh, err := NewLightHouseFromConfig(t.Context(), l, c, certState, nil, nil)
	require.NoError(t, err)

	staticHost := netip.MustParseAddr("10.128.0.2")
	otherHost := netip.MustParseAddr("10.128.0.3")

	// reload applies a config whose only section is the given static_host_map.
	reload := func(hosts map[string]any) {
		raw, err := yaml.Marshal(map[string]any{"static_host_map": hosts})
		require.NoError(t, err)
		require.NoError(t, c.ReloadConfigString(string(raw)))
	}
	// single builds the one-element address list the assertions compare against.
	single := func(addrPort string) []netip.AddrPort {
		return []netip.AddrPort{netip.MustParseAddrPort(addrPort)}
	}

	// Capture the RemoteList pointer up front; an in-flight handshake would hold the same one
	// on hostinfo.remotes, so it must reflect every reload below.
	pinned := lh.Query(staticHost)
	require.NotNil(t, pinned)
	assert.Equal(t, single("1.1.1.1:4242"), pinned.CopyAddrs([]netip.Prefix{}))

	// Replace the remote address. The new address should be the only entry.
	reload(map[string]any{"10.128.0.2": []any{"2.2.2.2:4242"}})
	current := lh.Query(staticHost)
	require.NotNil(t, current)
	assert.Same(t, pinned, current, "RemoteList pointer must stay stable so in-flight handshakes pick up the change")
	assert.Equal(t, single("2.2.2.2:4242"), current.CopyAddrs([]netip.Prefix{}))

	// Reload back to the original IP. Mirrors the round-trip in issue #718 step 6-8 where
	// the buggy reload produced [1.1.1.1, 2.2.2.2, 1.1.1.1] instead of [1.1.1.1].
	original := map[string]any{"10.128.0.2": []any{"1.1.1.1:4242"}}
	reload(original)
	current = lh.Query(staticHost)
	require.NotNil(t, current)
	assert.Same(t, pinned, current)
	assert.Equal(t, single("1.1.1.1:4242"), current.CopyAddrs([]netip.Prefix{}))

	// Reload with the same config. An unchanged entry must not duplicate.
	reload(original)
	current = lh.Query(staticHost)
	require.NotNil(t, current)
	assert.Same(t, pinned, current)
	assert.Equal(t, single("1.1.1.1:4242"), current.CopyAddrs([]netip.Prefix{}))

	// Switch back to 2.2.2.2 so the rest of the test continues against a known address.
	reload(map[string]any{"10.128.0.2": []any{"2.2.2.2:4242"}})

	// Add a second host alongside the first. Both should be present, neither duplicated.
	reload(map[string]any{
		"10.128.0.2": []any{"2.2.2.2:4242"},
		"10.128.0.3": []any{"3.3.3.3:4242"},
	})
	current = lh.Query(staticHost)
	require.NotNil(t, current)
	assert.Same(t, pinned, current, "adding a sibling entry must not displace the existing RemoteList")
	assert.Equal(t, single("2.2.2.2:4242"), current.CopyAddrs([]netip.Prefix{}))
	current = lh.Query(otherHost)
	require.NotNil(t, current)
	assert.Equal(t, single("3.3.3.3:4242"), current.CopyAddrs([]netip.Prefix{}))

	// Drop the first host entirely. The vpnAddr is no longer marked static, our owner
	// contribution is cleared, but the addrMap entry stays in place so non-static cache
	// data (from lighthouse queries) on the same RemoteList isn't lost. In-flight handshakes
	// that already had the pointer see an empty address list rather than retrying stale ones.
	reload(map[string]any{"10.128.0.3": []any{"3.3.3.3:4242"}})
	_, stillStatic := lh.GetStaticHostList()[staticHost]
	assert.False(t, stillStatic)
	current = lh.Query(staticHost)
	require.NotNil(t, current)
	assert.Same(t, pinned, current)
	assert.Empty(t, current.CopyAddrs([]netip.Prefix{}))
	current = lh.Query(otherHost)
	require.NotNil(t, current)
	assert.Equal(t, single("3.3.3.3:4242"), current.CopyAddrs([]netip.Prefix{}))
}
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
+5 -7
View File
@@ -55,7 +55,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
} }
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes()) l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(ctx, l.With("subsystem", "sshd"))
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
} }
@@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
} }
hostMap := NewHostMapFromConfig(l, c) hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c) punchy := NewPunchyFromConfig(l, c, udpConns[0])
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy) connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil { if err != nil {
@@ -184,21 +184,17 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
messageMetrics = newMessageMetricsOnlyRecvError() messageMetrics = newMessageMetricsOnlyRecvError()
} }
useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false)
handshakeConfig := HandshakeConfig{ handshakeConfig := HandshakeConfig{
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
useRelays: useRelays,
messageMetrics: messageMetrics, messageMetrics: messageMetrics,
} }
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger lightHouse.handshakeTrigger = handshakeManager.trigger
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) ds, err := newDnsServerFromConfig(ctx, l, pki, hostMap, c)
if err != nil { if err != nil {
l.Warn("Failed to start DNS responder", "error", err) l.Warn("Failed to start DNS responder", "error", err)
} }
@@ -244,6 +240,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
handshakeManager.f = ifce handshakeManager.f = ifce
go handshakeManager.Run(ctx) go handshakeManager.Run(ctx)
punchy.Start(ctx, ifce, hostMap, lightHouse)
} }
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest) stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
+8
View File
@@ -13,6 +13,8 @@ type MessageMetrics struct {
rxUnknown metrics.Counter rxUnknown metrics.Counter
txUnknown metrics.Counter txUnknown metrics.Counter
rxInvalid metrics.Counter
} }
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
@@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int
} }
} }
} }
// RxInvalid adds i to the rxInvalid counter. It does nothing when the receiver
// is nil or when the counter has not been set.
func (m *MessageMetrics) RxInvalid(i int64) {
	if m == nil || m.rxInvalid == nil {
		return
	}
	m.rxInvalid.Inc(i)
}
func newMessageMetrics() *MessageMetrics { func newMessageMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter { gen := func(t string) [][]metrics.Counter {
@@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics {
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil), rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil), txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil),
} }
} }
-88
View File
@@ -1,88 +0,0 @@
package nebula
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
// endianness is the single byte-order operation this file needs: writing a
// 64-bit value into a byte slice.
type endianness interface {
PutUint64(b []byte, v uint64)
}
// noiseEndianness is the byte order used to write the counter into the nonce
// buffer in EncryptDanger and DecryptDanger. It is big-endian here.
var noiseEndianness endianness = binary.BigEndian
// NebulaCipherState wraps the AEAD taken from a noise.CipherState so packets
// can be sealed and opened with a caller-supplied counter.
type NebulaCipherState struct {
// c is the AEAD that NewNebulaCipherState extracts from the noise cipher state.
c cipher.AEAD
}
// NewNebulaCipherState builds a NebulaCipherState around the cipher held by s.
// It panics if that cipher does not implement cipher.AEAD.
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
	return &NebulaCipherState{c: s.Cipher().(cipher.AEAD)}
}
// cipherAEADDanger is an optional interface an AEAD may implement to handle the
// counter-based encrypt/decrypt itself. NebulaCipherState delegates to it when
// the wrapped cipher satisfies it.
type cipherAEADDanger interface {
EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error)
DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error)
}
// EncryptDanger seals plaintext with the wrapped AEAD and appends the result to out.
//
//   - out is the destination slice the sealed output is appended to.
//   - ad is additional data that is authenticated but not encrypted.
//   - plaintext is the payload to encrypt and authenticate.
//   - n is the nonce counter; it must never be re-used with this key.
//   - nb is a scratch buffer the nonce is assembled in. Callers should re-use it
//     across calls to minimize garbage collection.
//
// An error is returned when the receiver is nil.
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return nil, errors.New("no cipher state available to encrypt")
	}

	// Let the cipher do the whole operation itself when it knows how to.
	if direct, ok := s.c.(cipherAEADDanger); ok {
		return direct.EncryptDanger(out, ad, plaintext, n, nb)
	}

	// TODO: Is this okay now that we have made messageCounter atomic?
	// Alternative may be to split the counter space into ranges. A duplicate
	// counter check (rejecting n <= the last n used) was considered here but
	// is not enabled.

	// Nonce layout: four zero bytes followed by the counter.
	nb[0], nb[1], nb[2], nb[3] = 0, 0, 0, 0
	noiseEndianness.PutUint64(nb[4:], n)
	return s.c.Seal(out, nb, plaintext, ad), nil
}
// DecryptDanger authenticates and opens ciphertext with the wrapped AEAD,
// appending the plaintext to out. The arguments mirror EncryptDanger.
//
// A nil receiver yields an empty slice and a nil error.
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return []byte{}, nil
	}

	// Let the cipher do the whole operation itself when it knows how to.
	if direct, ok := s.c.(cipherAEADDanger); ok {
		return direct.DecryptDanger(out, ad, ciphertext, n, nb)
	}

	// Nonce layout: four zero bytes followed by the counter.
	nb[0], nb[1], nb[2], nb[3] = 0, 0, 0, 0
	noiseEndianness.PutUint64(nb[4:], n)
	return s.c.Open(out, nb, ciphertext, ad)
}
// Overhead reports the wrapped AEAD's overhead, or 0 for a nil receiver.
func (s *NebulaCipherState) Overhead() int {
	if s == nil {
		return 0
	}
	return s.c.Overhead()
}
+53
View File
@@ -0,0 +1,53 @@
package noiseutil
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
// CipherStateAESGCM is the data-plane wrapper for the AES-GCM AEAD cipher.
// AES-GCM uses big-endian nonce encoding per the Noise spec.
type CipherStateAESGCM struct {
// c is the AEAD that NewCipherStateAESGCM extracts from the noise cipher state.
c cipher.AEAD
}
// NewCipherStateAESGCM pulls the underlying AEAD out of the post-handshake noise.CipherState.
// It is up to the caller to make sure the noise cipher really is AES-GCM. With any other
// AEAD the type assertion still succeeds, but the nonce endianness will be wrong on the wire.
func NewCipherStateAESGCM(s *noise.CipherState) *CipherStateAESGCM {
	aead := s.Cipher().(cipher.AEAD)
	return &CipherStateAESGCM{c: aead}
}
// EncryptDanger seals plaintext with ad as additional data and appends the
// result to out. The nonce is assembled in nb as four zero bytes followed by n
// in big-endian order. A nil receiver returns an error.
func (s *CipherStateAESGCM) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return nil, errors.New("no cipher state available to encrypt")
	}

	// Zero the four-byte prefix, then place the counter in bytes 4..11.
	binary.BigEndian.PutUint32(nb, 0)
	binary.BigEndian.PutUint64(nb[4:], n)

	sealed := s.c.Seal(out, nb, plaintext, ad)
	return sealed, nil
}
// DecryptDanger authenticates and opens ciphertext with ad as additional data,
// appending the plaintext to out. The nonce is assembled in nb as four zero
// bytes followed by n in big-endian order. A nil receiver returns an empty
// slice and a nil error.
func (s *CipherStateAESGCM) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return []byte{}, nil
	}

	// Zero the four-byte prefix, then place the counter in bytes 4..11.
	binary.BigEndian.PutUint32(nb, 0)
	binary.BigEndian.PutUint64(nb[4:], n)

	return s.c.Open(out, nb, ciphertext, ad)
}
// Overhead reports the wrapped AEAD's overhead, or 0 for a nil receiver.
func (s *CipherStateAESGCM) Overhead() int {
	if s != nil {
		return s.c.Overhead()
	}
	return 0
}
+52
View File
@@ -0,0 +1,52 @@
package noiseutil
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
// CipherStateChaChaPoly is the data-plane wrapper for the ChaCha20-Poly1305 AEAD cipher.
// ChaCha20-Poly1305 uses little-endian nonce encoding per the Noise spec.
type CipherStateChaChaPoly struct {
// c is the AEAD that NewCipherStateChaChaPoly extracts from the noise cipher state.
c cipher.AEAD
}
// NewCipherStateChaChaPoly pulls the underlying AEAD out of the post-handshake noise.CipherState.
// It is up to the caller to make sure the noise cipher really is ChaCha20-Poly1305.
func NewCipherStateChaChaPoly(s *noise.CipherState) *CipherStateChaChaPoly {
	aead := s.Cipher().(cipher.AEAD)
	return &CipherStateChaChaPoly{c: aead}
}
// EncryptDanger seals plaintext with ad as additional data and appends the
// result to out. The nonce is assembled in nb as four zero bytes followed by n
// in little-endian order. A nil receiver returns an error.
func (s *CipherStateChaChaPoly) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return nil, errors.New("no cipher state available to encrypt")
	}

	// Zero the four-byte prefix, then place the counter in bytes 4..11.
	binary.LittleEndian.PutUint32(nb, 0)
	binary.LittleEndian.PutUint64(nb[4:], n)

	sealed := s.c.Seal(out, nb, plaintext, ad)
	return sealed, nil
}
// DecryptDanger authenticates and opens ciphertext with ad as additional data,
// appending the plaintext to out. The nonce is assembled in nb as four zero
// bytes followed by n in little-endian order. A nil receiver returns an empty
// slice and a nil error.
func (s *CipherStateChaChaPoly) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return []byte{}, nil
	}

	// Zero the four-byte prefix, then place the counter in bytes 4..11.
	binary.LittleEndian.PutUint32(nb, 0)
	binary.LittleEndian.PutUint64(nb[4:], n)

	return s.c.Open(out, nb, ciphertext, ad)
}
// Overhead reports the wrapped AEAD's overhead, or 0 for a nil receiver.
func (s *CipherStateChaChaPoly) Overhead() int {
	if s != nil {
		return s.c.Overhead()
	}
	return 0
}
+40
View File
@@ -0,0 +1,40 @@
package noiseutil
import (
"fmt"
"github.com/flynn/noise"
)
// CipherState is the post-handshake AEAD cipher used for the data plane.
// Each supported cipher has its own concrete implementation in this package with the nonce endianness hardcoded,
// so the encrypt/decrypt fast path avoids interface dispatch on the byte order.
//
// The implementations in this package build the nonce in nb as four zero bytes
// followed by the 8-byte counter, so nb must hold at least 12 bytes.
type CipherState interface {
// EncryptDanger encrypts and authenticates a given payload.
//
// out is a destination slice to hold the output of the EncryptDanger operation.
// - ad is additional data, which will be authenticated and appended to out, but not encrypted.
// - plaintext is encrypted, authenticated and appended to out.
// - n is a nonce value which must never be re-used with this key.
// - nb is a scratch buffer used to assemble the nonce.
//
// The implementations in this package return an error when called on a nil receiver.
EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error)
// DecryptDanger authenticates and decrypts a given payload, with the same argument shape as EncryptDanger.
// The implementations in this package return an empty slice and a nil error when called on a nil receiver.
DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error)
// Overhead returns the AEAD tag size, or 0 if the receiver is nil.
Overhead() int
}
// NewCipherState picks the per-cipher wrapper that matches cipherFunc and wraps the
// post-handshake noise.CipherState in it. cipherFunc must be the cipher that was used to
// build the noise CipherSuite that produced s. Any cipher other than AES-GCM or
// ChaCha20-Poly1305 causes a panic.
func NewCipherState(s *noise.CipherState, cipherFunc noise.CipherFunc) CipherState {
	name := cipherFunc.CipherName()
	if name == CipherAESGCM.CipherName() {
		return NewCipherStateAESGCM(s)
	}
	if name == noise.CipherChaChaPoly.CipherName() {
		return NewCipherStateChaChaPoly(s)
	}
	panic(fmt.Sprintf("noiseutil: unsupported cipher %q", name))
}
+166
View File
@@ -0,0 +1,166 @@
package noiseutil
import (
"testing"
"github.com/flynn/noise"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCipherStateAESGCMRoundtrip runs the shared roundtrip checks over a pair
// of AES-GCM cipher states taken from one handshake.
func TestCipherStateAESGCMRoundtrip(t *testing.T) {
	initiator, responder := buildCipherStates(t, CipherAESGCM)
	sealer := NewCipherStateAESGCM(initiator)
	opener := NewCipherStateAESGCM(responder)
	roundtrip(t, sealer, opener)
}
// TestCipherStateChaChaPolyRoundtrip runs the shared roundtrip checks over a
// pair of ChaCha20-Poly1305 cipher states taken from one handshake.
func TestCipherStateChaChaPolyRoundtrip(t *testing.T) {
	initiator, responder := buildCipherStates(t, noise.CipherChaChaPoly)
	sealer := NewCipherStateChaChaPoly(initiator)
	opener := NewCipherStateChaChaPoly(responder)
	roundtrip(t, sealer, opener)
}
// TestNewCipherStateDispatch checks that NewCipherState returns the concrete
// wrapper type matching the cipher function it is given.
func TestNewCipherStateDispatch(t *testing.T) {
	aesState, _ := buildCipherStates(t, CipherAESGCM)
	chachaState, _ := buildCipherStates(t, noise.CipherChaChaPoly)

	gotAES := NewCipherState(aesState, CipherAESGCM)
	assert.IsType(t, &CipherStateAESGCM{}, gotAES)

	gotChaCha := NewCipherState(chachaState, noise.CipherChaChaPoly)
	assert.IsType(t, &CipherStateChaChaPoly{}, gotChaCha)
}
// TestNewCipherStateUnsupportedPanics checks that NewCipherState panics when
// handed a cipher function it does not recognise.
func TestNewCipherStateUnsupportedPanics(t *testing.T) {
	state, _ := buildCipherStates(t, CipherAESGCM)
	wrapUnsupported := func() {
		NewCipherState(state, fakeCipher{})
	}
	assert.Panics(t, wrapUnsupported)
}
// fakeCipher is a stub cipher function whose name matches none of the ciphers
// NewCipherState supports. TestNewCipherStateUnsupportedPanics uses it to reach
// the panic branch.
type fakeCipher struct{}
// Cipher always returns nil; the stub never produces a usable cipher.
func (fakeCipher) Cipher(k [32]byte) noise.Cipher { return nil }
// CipherName reports the fixed name "Fake".
func (fakeCipher) CipherName() string { return "Fake" }
// buildCipherStates performs an in-memory NN handshake using the requested
// cipher and returns a pair of post-handshake CipherStates that share keys:
// the first from the initiator, the second from the responder.
func buildCipherStates(t *testing.T, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
	t.Helper()

	cfg := noise.Config{
		CipherSuite: noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256),
		Pattern:     noise.HandshakeNN,
	}

	cfg.Initiator = true
	initiator, err := noise.NewHandshakeState(cfg)
	require.NoError(t, err)

	cfg.Initiator = false
	responder, err := noise.NewHandshakeState(cfg)
	require.NoError(t, err)

	// First message: initiator -> responder.
	wire, _, _, err := initiator.WriteMessage(nil, nil)
	require.NoError(t, err)
	_, _, _, err = responder.ReadMessage(nil, wire)
	require.NoError(t, err)

	// Second message: responder -> initiator. Each side gets its cipher states here.
	wire, responderCS, _, err := responder.WriteMessage(nil, nil)
	require.NoError(t, err)
	_, initiatorCS, _, err := initiator.ReadMessage(nil, wire)
	require.NoError(t, err)

	require.NotNil(t, initiatorCS)
	require.NotNil(t, responderCS)

	// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
	return initiatorCS, responderCS
}
// roundtrip holds the checks shared by the per-cipher roundtrip tests: enc
// seals a message that dec can open, a wrong nonce is rejected, and both sides
// report a 16-byte overhead.
func roundtrip(t *testing.T, enc, dec CipherState) {
	t.Helper()

	var (
		message = []byte("nebula cipher state roundtrip")
		aad     = []byte("aad")
		scratch = make([]byte, 12)
	)

	sealed, err := enc.EncryptDanger(nil, aad, message, 1, scratch)
	require.NoError(t, err)
	assert.NotEqual(t, message, sealed)

	opened, err := dec.DecryptDanger(nil, aad, sealed, 1, scratch)
	require.NoError(t, err)
	assert.Equal(t, message, opened)

	// Wrong nonce must fail authentication.
	_, err = dec.DecryptDanger(nil, aad, sealed, 2, scratch)
	require.Error(t, err)

	assert.Equal(t, enc.Overhead(), dec.Overhead())
	assert.Equal(t, 16, enc.Overhead())
}
// BenchmarkCipherStateEncryptAESGCM measures EncryptDanger through the
// CipherState interface for AES-GCM.
func BenchmarkCipherStateEncryptAESGCM(b *testing.B) {
	initiator, _ := buildCipherStatesB(b, CipherAESGCM)
	cs := NewCipherState(initiator, CipherAESGCM)
	benchEncryptCipherState(b, cs)
}
// BenchmarkCipherStateEncryptChaChaPoly measures EncryptDanger through the
// CipherState interface for ChaCha20-Poly1305.
func BenchmarkCipherStateEncryptChaChaPoly(b *testing.B) {
	initiator, _ := buildCipherStatesB(b, noise.CipherChaChaPoly)
	cs := NewCipherState(initiator, noise.CipherChaChaPoly)
	benchEncryptCipherState(b, cs)
}
// benchEncryptCipherState encrypts a 1280-byte payload with 16 bytes of
// additional data b.N times, re-using one output buffer and one nonce buffer,
// with a fresh counter on every iteration.
func benchEncryptCipherState(b *testing.B, cs CipherState) {
	payload := make([]byte, 1280)
	aad := make([]byte, 16)
	scratch := make([]byte, 12)
	// Sized up front so sealing in place never has to grow the buffer.
	buf := make([]byte, 0, len(payload)+cs.Overhead())

	b.ResetTimer()
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		sealed, err := cs.EncryptDanger(buf[:0], aad, payload, uint64(i+1), scratch)
		buf = sealed
		if err != nil {
			b.Fatal(err)
		}
	}
}
// buildCipherStatesB is the benchmark counterpart of buildCipherStates. It runs
// an in-memory NN handshake with the requested cipher and returns a pair of
// post-handshake CipherStates that share keys: the first from the initiator,
// the second from the responder. Any handshake failure aborts the benchmark.
func buildCipherStatesB(b *testing.B, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
	b.Helper()
	suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256)
	cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN}

	cfg.Initiator = true
	hsI, err := noise.NewHandshakeState(cfg)
	if err != nil {
		b.Fatal(err)
	}
	cfg.Initiator = false
	hsR, err := noise.NewHandshakeState(cfg)
	if err != nil {
		b.Fatal(err)
	}

	// First message: initiator -> responder.
	msg, _, _, err := hsI.WriteMessage(nil, nil)
	if err != nil {
		b.Fatal(err)
	}
	if _, _, _, err := hsR.ReadMessage(nil, msg); err != nil {
		b.Fatal(err)
	}

	// Second message: responder -> initiator. Each side gets its cipher states here.
	msg, dR, _, err := hsR.WriteMessage(nil, nil)
	if err != nil {
		b.Fatal(err)
	}
	_, eI, _, err := hsI.ReadMessage(nil, msg)
	if err != nil {
		b.Fatal(err)
	}

	// Match the non-nil checks in buildCipherStates, so an unfinished handshake
	// fails here with a clear message instead of a nil dereference later.
	if eI == nil || dR == nil {
		b.Fatal("handshake completed without producing cipher states")
	}
	return eI, dR
}
// TestCipherStateNilSafety verifies that nil *CipherStateAESGCM and nil
// *CipherStateChaChaPoly receivers do not panic: EncryptDanger returns an
// error, DecryptDanger returns empty output with no error, and Overhead
// reports 0.
func TestCipherStateNilSafety(t *testing.T) {
	freshNonce := func() []byte { return make([]byte, 12) }

	var nilAES *CipherStateAESGCM
	_, encErr := nilAES.EncryptDanger(nil, nil, nil, 0, freshNonce())
	require.Error(t, encErr)
	plain, decErr := nilAES.DecryptDanger(nil, nil, nil, 0, freshNonce())
	require.NoError(t, decErr)
	assert.Empty(t, plain)
	assert.Equal(t, 0, nilAES.Overhead())

	var nilChaCha *CipherStateChaChaPoly
	_, encErr = nilChaCha.EncryptDanger(nil, nil, nil, 0, freshNonce())
	require.Error(t, encErr)
	plain, decErr = nilChaCha.DecryptDanger(nil, nil, nil, 0, freshNonce())
	require.NoError(t, decErr)
	assert.Empty(t, plain)
	assert.Equal(t, 0, nilChaCha.Overhead())
}
+126 -164
View File
@@ -20,23 +20,46 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
var ErrOutOfWindow = errors.New("out of window packet")
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
// TODO: record metrics for rx holepunch/punchy packets?
if len(packet) > 1 { if len(packet) > 1 {
f.l.Info("Error while parsing inbound packet", f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while parsing inbound packet",
"from", via, "from", via,
"error", err, "error", err,
"packet", packet, "packet", packet,
) )
} }
}
return
}
if h.Version != header.Version {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected header version received", "from", via)
}
return
}
// Check before processing to see if this is a expected type/subtype
if !h.IsValidSubType() {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected packet received", "from", via)
}
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed { if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Refusing to process double encrypted packet", "from", via) f.l.Debug("Refusing to process double encrypted packet", "from", via)
} }
@@ -44,31 +67,108 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
} }
} }
// don't keep Rx metrics for message type, since you can see those in the tun metrics
if h.Type != header.Message {
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
}
// Unencrypted packets
switch h.Type {
case header.Handshake:
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.handleRecvError(via.UdpAddr, h)
return
}
// Relay packets are special
isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay)
var hostinfo *HostInfo var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation if isMessageRelay {
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else { } else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
} }
var ci *ConnectionState // At this point we should have a valid existing tunnel, verify and send
if hostinfo != nil { // recvError if necessary
ci = hostinfo.ConnectionState if hostinfo == nil || hostinfo.ConnectionState == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
} }
return
}
// All remaining packets are encrypted
ci := hostinfo.ConnectionState
if !ci.window.Check(f.l, h.MessageCounter) {
return
}
// Relay packets are special
if isMessageRelay {
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache)
return
}
out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Failed to decrypt packet",
"error", err,
"from", via,
"header", h,
)
}
return
}
// Roam before we respond
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
switch h.Type { switch h.Type {
case header.Message: case header.Message:
if !f.handleEncrypted(ci, via, h) { switch h.Subtype {
case header.MessageNone:
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
return return
} }
case header.LightHouse:
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f)
case header.Test:
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.TestReply:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { // No-op, useful for the Roaming and connectionManager side-effects above
case header.TestRequest:
f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h)
return return
} }
case header.MessageRelay:
case header.CloseTunnel:
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
case header.Control:
f.relayManager.HandleControlMsg(hostinfo, out, f)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h)
}
}
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
// The entire body is sent as AD, not encrypted. // The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
@@ -76,6 +176,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// which will gracefully fail in the DecryptDanger call. // which will gracefully fail in the DecryptDanger call.
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil { if err != nil {
return return
@@ -93,8 +194,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen. // its internal mapping. This should never happen.
hostinfo.logger(f.l).Error("HostInfo missing remote relay index", hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
"vpnAddrs", hostinfo.vpnAddrs, "relayRemoteIndex", h.RemoteIndex,
"remoteIndex", h.RemoteIndex,
) )
return return
} }
@@ -111,15 +211,14 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
IsRelayed: true, IsRelayed: true,
} }
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil { if err != nil {
hostinfo.logger(f.l).Info("Failed to find target host info by ip", hostinfo.logger(f.l).Info("Failed to find target host info by ip",
"relayTo", relay.PeerAddr, "relayTo", relay.PeerAddr,
"relayFrom", hostinfo.vpnAddrs[0],
"error", err, "error", err,
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
) )
return return
} }
@@ -131,9 +230,14 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Forward this packet through the relay tunnel // Forward this packet through the relay tunnel
// Find the target HostInfo // Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
return
case TerminalType: case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
return
default:
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type)
}
return
} }
} else { } else {
hostinfo.logger(f.l).Info("Unexpected target relay state", hostinfo.logger(f.l).Info("Unexpected target relay state",
@@ -143,116 +247,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
) )
return return
} }
}
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(via.UdpAddr, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type)
} }
return
} }
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
} }
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -300,23 +299,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
} }
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
}
var ( var (
ErrPacketTooShort = errors.New("packet is too short") ErrPacketTooShort = errors.New("packet is too short")
ErrUnknownIPVersion = errors.New("packet is an unknown ip version") ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
@@ -523,38 +505,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
} }
if !hostinfo.ConnectionState.window.Update(f.l, mc) { if !hostinfo.ConnectionState.window.Update(f.l, mc) {
if f.l.Enabled(context.Background(), slog.LevelDebug) { return nil, ErrOutOfWindow
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
}
return nil, errors.New("out of window packet")
} }
return out, nil return out, nil
} }
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
var err error err := newPacket(out, true, fwPacket)
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
return false
}
err = newPacket(out, true, fwPacket)
if err != nil { if err != nil {
hostinfo.logger(f.l).Warn("Error while validating inbound packet", hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err, "error", err,
"packet", out, "packet", out,
) )
return false return
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
}
return false
} }
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
@@ -568,15 +532,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
"reason", dropReason, "reason", dropReason,
) )
} }
return false return
} }
f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out) _, err = f.readers[q].Write(out)
if err != nil { if err != nil {
f.l.Error("Failed to write to tun", "error", err) f.l.Error("Failed to write to tun", "error", err)
} }
return true
} }
func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
+358
View File
@@ -0,0 +1,358 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"errors"
"fmt"
"log/slog"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h.
type networkCategory int32

// The category values this package knows how to name and parse.
const (
	networkCategoryPublic              networkCategory = 0
	networkCategoryPrivate             networkCategory = 1
	networkCategoryDomainAuthenticated networkCategory = 2
)

// String returns "public", "private" or "domain" for the known categories and
// "unknown(<n>)" for any other value.
func (c networkCategory) String() string {
	names := [...]string{
		networkCategoryPublic:              "public",
		networkCategoryPrivate:             "private",
		networkCategoryDomainAuthenticated: "domain",
	}
	if c < 0 || int(c) >= len(names) {
		return fmt.Sprintf("unknown(%d)", c)
	}
	return names[c]
}

// parseNetworkCategory interprets the user-supplied tun.network_category
// setting. Matching ignores case and surrounding whitespace. The boolean
// result is false when the value is empty or "unset", meaning "leave the
// category alone"; an unrecognized value yields an error.
func parseNetworkCategory(s string) (networkCategory, bool, error) {
	key := strings.ToLower(strings.TrimSpace(s))

	var cat networkCategory
	switch {
	case key == "" || key == "unset":
		return 0, false, nil
	case key == "public":
		cat = networkCategoryPublic
	case key == "private":
		cat = networkCategoryPrivate
	case key == "domain" || key == "domainauthenticated":
		cat = networkCategoryDomainAuthenticated
	default:
		return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s)
	}
	return cat, true, nil
}
// clsidNetworkListManager is CLSID_NetworkListManager
// {DCB00C01-570F-4A9B-8D69-199FDBA5723B}, the class ID that
// createNetworkListManager passes to CoCreateInstance.
var clsidNetworkListManager = windows.GUID{
Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// iidINetworkListManager is IID_INetworkListManager
// {DCB00000-570F-4A9B-8D69-199FDBA5723B}, the interface ID that
// createNetworkListManager requests from CoCreateInstance.
var iidINetworkListManager = windows.GUID{
Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves.
// procCoCreateInstance is the ole32.dll entry point called by
// createNetworkListManager.
var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance")
// clsCtxAll is the class-context argument given to CoCreateInstance: the
// in-process server, in-process handler, local server and remote server
// flags ORed together.
const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER |
windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER
// Return codes that coInit special-cases when CoInitializeEx reports a
// non-nil error.
const (
// hrSFALSE is treated by coInit as a successful init that still needs a
// matching CoUninitialize.
hrSFALSE = 0x00000001
// hrRPCEChangedMode is treated by coInit as "COM is usable, but do not
// uninitialize it".
hrRPCEChangedMode = 0x80010106
)
// hresult is a COM status code as returned in r1 by the syscalls in this file.
type hresult uint32

// failed reports whether the high (sign) bit is set, i.e. whether the value
// is negative when read as a signed 32-bit integer.
func (h hresult) failed() bool {
	return h&0x80000000 != 0
}

// String renders the code as "HRESULT 0x" followed by eight lowercase hex
// digits, for use in error messages.
func (h hresult) String() string {
	return fmt.Sprintf("HRESULT 0x%08x", uint32(h))
}
// errAdapterNotFound is returned by setNetworkCategory when the connection
// enumeration ends without a matching adapter GUID. applyNetworkCategory
// treats it as "not visible yet" and retries.
var errAdapterNotFound = errors.New("adapter not present in network connections enumeration")
// Vtable layouts. Slot order must match the declaration order in netlistmgr.h.
// All NLM interfaces here derive from IDispatch, which derives from IUnknown.
// Each field holds the address of one method; the wrappers in this file pass
// those addresses to syscall.SyscallN.

// iUnknownVtbl holds the three IUnknown slots that lead every vtable below.
type iUnknownVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
}
// iDispatchVtbl extends iUnknownVtbl with the four IDispatch slots. It is
// embedded first in every NLM vtable so the interface-specific slots land
// after them.
type iDispatchVtbl struct {
iUnknownVtbl
GetTypeInfoCount uintptr
GetTypeInfo uintptr
GetIDsOfNames uintptr
Invoke uintptr
}
// iNetworkListManagerVtbl is the INetworkListManager method table: the
// IDispatch slots followed by the interface's own methods. Field order is
// significant and must not change.
type iNetworkListManagerVtbl struct {
	iDispatchVtbl
	GetNetworks           uintptr
	GetNetwork            uintptr
	GetNetworkConnections uintptr
	GetNetworkConnection  uintptr
	IsConnectedToInternet uintptr
	IsConnected           uintptr
	GetConnectivity       uintptr
}

// iNetworkListManager overlays a COM INetworkListManager object; its only
// field is the vtable pointer.
type iNetworkListManager struct {
	Vtbl *iNetworkListManagerVtbl
}

// Release calls the object's Release slot; the return values are discarded.
func (n *iNetworkListManager) Release() {
	_, _, _ = syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}

// GetNetworkConnections calls the GetNetworkConnections slot and returns the
// resulting enumerator, or an error carrying the failing HRESULT.
func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) {
	var result *iEnumNetworkConnections
	// The unsafe.Pointer -> uintptr conversions are kept inline in the call
	// expression, as the unsafe package requires for syscall arguments.
	ret, _, _ := syscall.SyscallN(
		n.Vtbl.GetNetworkConnections,
		uintptr(unsafe.Pointer(n)),
		uintptr(unsafe.Pointer(&result)),
	)
	hr := hresult(ret)
	if hr.failed() {
		return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr)
	}
	return result, nil
}
// iEnumNetworkConnectionsVtbl is the IEnumNetworkConnections method table:
// the IDispatch slots followed by the enumerator's own methods. Field order
// is significant and must not change.
type iEnumNetworkConnectionsVtbl struct {
	iDispatchVtbl
	NewEnum uintptr
	Next    uintptr
	Skip    uintptr
	Reset   uintptr
	Clone   uintptr
}

// iEnumNetworkConnections overlays a COM IEnumNetworkConnections object; its
// only field is the vtable pointer.
type iEnumNetworkConnections struct {
	Vtbl *iEnumNetworkConnectionsVtbl
}

// Release calls the object's Release slot; the return values are discarded.
func (e *iEnumNetworkConnections) Release() {
	_, _, _ = syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e)))
}

// Next requests a single element from the enumerator. It returns the next
// connection, (nil, nil) once nothing more was fetched, or an error carrying
// the failing HRESULT.
func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) {
	var (
		item  *iNetworkConnection
		count uint32
	)
	// The unsafe.Pointer -> uintptr conversions are kept inline in the call
	// expression, as the unsafe package requires for syscall arguments.
	ret, _, _ := syscall.SyscallN(
		e.Vtbl.Next,
		uintptr(unsafe.Pointer(e)),
		1, // number of elements requested
		uintptr(unsafe.Pointer(&item)),
		uintptr(unsafe.Pointer(&count)),
	)
	hr := hresult(ret)
	switch {
	case hr.failed():
		return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr)
	case count == 0:
		return nil, nil
	default:
		return item, nil
	}
}
// iNetworkConnectionVtbl is the INetworkConnection method table: the
// IDispatch slots followed by the interface's own methods. Field order is
// significant and must not change.
type iNetworkConnectionVtbl struct {
	iDispatchVtbl
	GetNetwork            uintptr
	IsConnectedToInternet uintptr
	IsConnected           uintptr
	GetConnectivity       uintptr
	GetConnectionId       uintptr
	GetAdapterId          uintptr
	GetDomainType         uintptr
}

// iNetworkConnection overlays a COM INetworkConnection object; its only field
// is the vtable pointer.
type iNetworkConnection struct {
	Vtbl *iNetworkConnectionVtbl
}

// Release calls the object's Release slot; the return values are discarded.
func (c *iNetworkConnection) Release() {
	_, _, _ = syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c)))
}

// GetAdapterId calls the GetAdapterId slot and returns the GUID it fills in,
// or the zero GUID plus an error carrying the failing HRESULT.
func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) {
	var adapter windows.GUID
	// The unsafe.Pointer -> uintptr conversions are kept inline in the call
	// expression, as the unsafe package requires for syscall arguments.
	ret, _, _ := syscall.SyscallN(
		c.Vtbl.GetAdapterId,
		uintptr(unsafe.Pointer(c)),
		uintptr(unsafe.Pointer(&adapter)),
	)
	hr := hresult(ret)
	if hr.failed() {
		return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr)
	}
	return adapter, nil
}

// GetNetwork calls the GetNetwork slot and returns the network object it
// yields, or an error carrying the failing HRESULT.
func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) {
	var network *iNetwork
	ret, _, _ := syscall.SyscallN(
		c.Vtbl.GetNetwork,
		uintptr(unsafe.Pointer(c)),
		uintptr(unsafe.Pointer(&network)),
	)
	hr := hresult(ret)
	if hr.failed() {
		return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr)
	}
	return network, nil
}
// iNetworkVtbl is the INetwork method table: the IDispatch slots followed by
// the interface's own methods. Field order is significant and must not change.
type iNetworkVtbl struct {
	iDispatchVtbl
	GetName                    uintptr
	SetName                    uintptr
	GetDescription             uintptr
	SetDescription             uintptr
	GetNetworkId               uintptr
	GetDomainType              uintptr
	GetNetworkConnections      uintptr
	GetTimeCreatedAndConnected uintptr
	IsConnectedToInternet      uintptr
	IsConnected                uintptr
	GetConnectivity            uintptr
	GetCategory                uintptr
	SetCategory                uintptr
}

// iNetwork overlays a COM INetwork object; its only field is the vtable
// pointer.
type iNetwork struct {
	Vtbl *iNetworkVtbl
}

// Release calls the object's Release slot; the return values are discarded.
func (n *iNetwork) Release() {
	_, _, _ = syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}

// GetCategory calls the GetCategory slot and returns the category it writes,
// or an error carrying the failing HRESULT.
func (n *iNetwork) GetCategory() (networkCategory, error) {
	var current networkCategory
	// The unsafe.Pointer -> uintptr conversions are kept inline in the call
	// expression, as the unsafe package requires for syscall arguments.
	ret, _, _ := syscall.SyscallN(
		n.Vtbl.GetCategory,
		uintptr(unsafe.Pointer(n)),
		uintptr(unsafe.Pointer(&current)),
	)
	hr := hresult(ret)
	if hr.failed() {
		return 0, fmt.Errorf("INetwork.GetCategory: %s", hr)
	}
	return current, nil
}

// SetCategory calls the SetCategory slot with c passed by value and returns
// an error carrying the HRESULT if the call fails.
func (n *iNetwork) SetCategory(c networkCategory) error {
	ret, _, _ := syscall.SyscallN(
		n.Vtbl.SetCategory,
		uintptr(unsafe.Pointer(n)),
		uintptr(int32(c)),
	)
	hr := hresult(ret)
	if hr.failed() {
		return fmt.Errorf("INetwork.SetCategory: %s", hr)
	}
	return nil
}
// coInit initializes COM (multithreaded mode) for the current OS thread and
// returns the function the caller must defer to balance the init.
//
//   - nil error or S_FALSE: the init counts, so the returned function is
//     CoUninitialize.
//   - RPC_E_CHANGED_MODE: COM is already initialized in a different mode on
//     this thread. That is still fine for our calls, but we must not
//     Uninitialize, so the returned function does nothing.
//   - anything else: a wrapped error and a nil function.
func coInit() (func(), error) {
	err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
	if err == nil {
		return windows.CoUninitialize, nil
	}

	errno, isErrno := err.(syscall.Errno)
	if isErrno && uint32(errno) == hrSFALSE {
		return windows.CoUninitialize, nil
	}
	if isErrno && uint32(errno) == hrRPCEChangedMode {
		return func() {}, nil
	}
	return nil, fmt.Errorf("CoInitializeEx: %w", err)
}
// createNetworkListManager instantiates the Network List Manager COM class
// through CoCreateInstance and returns its INetworkListManager interface, or
// an error carrying the failing HRESULT. The caller is responsible for
// calling Release on the result.
func createNetworkListManager() (*iNetworkListManager, error) {
	var manager *iNetworkListManager
	// The unsafe.Pointer -> uintptr conversions are kept inline in the call
	// expression, as the unsafe package requires for syscall arguments.
	ret, _, _ := procCoCreateInstance.Call(
		uintptr(unsafe.Pointer(&clsidNetworkListManager)), // rclsid
		0,                  // no aggregating outer object
		uintptr(clsCtxAll), // class context
		uintptr(unsafe.Pointer(&iidINetworkListManager)), // riid
		uintptr(unsafe.Pointer(&manager)),                // out pointer
	)
	hr := hresult(ret)
	if hr.failed() {
		return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr)
	}
	return manager, nil
}
// setNetworkCategory walks the NLM connection enumeration looking for the
// connection whose adapter id equals adapterGUID, then sets cat on that
// connection's parent network. It returns errAdapterNotFound when the
// enumeration ends without a match (the adapter is not yet visible to NLM).
// Connections whose adapter id cannot be read are skipped, not treated as
// fatal. COM is initialized on entry and every object obtained here is
// released before returning.
func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error {
	uninit, err := coInit()
	if err != nil {
		return err
	}
	defer uninit()

	manager, err := createNetworkListManager()
	if err != nil {
		return err
	}
	defer manager.Release()

	connections, err := manager.GetNetworkConnections()
	if err != nil {
		return err
	}
	defer connections.Release()

	for {
		conn, nextErr := connections.Next()
		switch {
		case nextErr != nil:
			return nextErr
		case conn == nil:
			// Enumeration exhausted without finding the adapter.
			return errAdapterNotFound
		}

		id, idErr := conn.GetAdapterId()
		if idErr == nil && id == adapterGUID {
			network, netErr := conn.GetNetwork()
			conn.Release()
			if netErr != nil {
				return netErr
			}
			setErr := network.SetCategory(cat)
			network.Release()
			return setErr
		}

		// Not our adapter (or its id was unreadable): drop it and move on.
		conn.Release()
	}
}
// applyNetworkCategory retries setNetworkCategory until the adapter appears
// in the NLM enumeration, making up to 30 attempts spaced 500ms apart. It
// logs the outcome: Info on success, Warn on any error other than
// errAdapterNotFound (which ends the polling immediately), and Warn if all
// attempts are used up. Intended to run in its own goroutine.
func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) {
	// COM Init/Uninit must be paired on the same OS thread.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	const (
		maxTries = 30
		pause    = 500 * time.Millisecond
	)

	for try := 1; try <= maxTries; try++ {
		switch err := setNetworkCategory(adapterGUID, cat); {
		case err == nil:
			l.Info("Set Windows network category", "category", cat.String())
			return
		case !errors.Is(err, errAdapterNotFound):
			l.Warn("Failed to set Windows network category", "error", err, "category", cat.String())
			return
		}
		// Adapter not visible yet; wait before polling again.
		time.Sleep(pause)
	}

	l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set",
		"category", cat.String(),
		"waited", time.Duration(maxTries)*pause,
	)
}
+109
View File
@@ -0,0 +1,109 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
	"runtime"
	"testing"
)
// Test_parseNetworkCategory is a table test for parseNetworkCategory covering
// the "leave alone" inputs (empty / unset), each accepted category in varying
// case and with surrounding whitespace, and rejected strings.
func Test_parseNetworkCategory(t *testing.T) {
	type expectation struct {
		in        string
		wantCat   networkCategory
		wantApply bool
		wantErr   bool
	}

	table := []expectation{
		{"", 0, false, false},
		{"unset", 0, false, false},
		{" UNSET ", 0, false, false},
		{"private", networkCategoryPrivate, true, false},
		{"Private", networkCategoryPrivate, true, false},
		{" PRIVATE ", networkCategoryPrivate, true, false},
		{"public", networkCategoryPublic, true, false},
		{"PUBLIC", networkCategoryPublic, true, false},
		{"domain", networkCategoryDomainAuthenticated, true, false},
		{"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false},
		{"garbage", 0, false, true},
		{"privates", 0, false, true},
	}

	for i := range table {
		want := table[i]
		gotCat, gotApply, err := parseNetworkCategory(want.in)
		if gotErr := err != nil; gotErr != want.wantErr {
			t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", want.in, err, want.wantErr)
			continue
		}
		if gotCat == want.wantCat && gotApply == want.wantApply {
			continue
		}
		t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", want.in, gotCat, gotApply, want.wantCat, want.wantApply)
	}
}
// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory
// without mutating the host's network state. It validates the CLSID/IID
// constants and every vtable index by enumerating connections, fetching the
// adapter id and parent network, reading the current category, and writing it
// back unchanged.
//
// Requires Windows but does not require admin or the wintun driver. Skips if
// no network connections are available (unlikely outside of an isolated
// container).
func Test_NLM_round_trip(t *testing.T) {
	// COM Init/Uninit must be paired on the same OS thread, and the COM calls
	// in between must run on that initialized thread. Pin this goroutine for
	// the whole test so the scheduler cannot move it mid-way. The unlock is
	// deferred first so it runs after the COM teardown below.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	deinit, err := coInit()
	if err != nil {
		t.Fatalf("coInit: %v", err)
	}
	defer deinit()

	nlm, err := createNetworkListManager()
	if err != nil {
		t.Fatalf("createNetworkListManager: %v", err)
	}
	defer nlm.Release()

	enum, err := nlm.GetNetworkConnections()
	if err != nil {
		t.Fatalf("GetNetworkConnections: %v", err)
	}
	defer enum.Release()

	saw := 0
	for {
		conn, err := enum.Next()
		if err != nil {
			t.Fatalf("EnumNetworkConnections.Next: %v", err)
		}
		if conn == nil {
			// End of enumeration.
			break
		}
		saw++

		// Objects are released by hand before each Fatalf because a defer
		// inside the loop would not run until the test function returns.
		if _, err := conn.GetAdapterId(); err != nil {
			conn.Release()
			t.Fatalf("INetworkConnection.GetAdapterId: %v", err)
		}
		net, err := conn.GetNetwork()
		conn.Release()
		if err != nil {
			t.Fatalf("INetworkConnection.GetNetwork: %v", err)
		}
		cat, err := net.GetCategory()
		if err != nil {
			net.Release()
			t.Fatalf("INetwork.GetCategory: %v", err)
		}
		// Set to the current value so the host's NLM state is unchanged but
		// SetCategory's vtable slot is still validated end-to-end.
		if err := net.SetCategory(cat); err != nil {
			net.Release()
			t.Fatalf("INetwork.SetCategory(%v): %v", cat, err)
		}
		net.Release()
	}
	if saw == 0 {
		t.Skip("no NLM network connections available; skipping round-trip")
	}
}
+23
View File
@@ -0,0 +1,23 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
package overlay
import (
"log/slog"
"github.com/slackhq/nebula/wfp"
)
// installInterfaceBypass asks the wfp package for a PERMIT filter scoped to
// the wintun interface LUID so inbound traffic on the nebula adapter bypasses
// Windows Defender Firewall. On success it logs at Info and returns the value
// from wfp.PermitInterface as a closer; on failure it logs a warning and
// returns nil.
func installInterfaceBypass(l *slog.Logger, luid uint64) closer {
	session, err := wfp.PermitInterface(luid)
	if err == nil {
		l.Info("Installed WFP filters bypassing Windows Defender Firewall on nebula interface")
		return session
	}

	l.Warn("Failed to install WFP bypass filters on nebula interface", "error", err)
	return nil
}
+11
View File
@@ -0,0 +1,11 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import "log/slog"
// installInterfaceBypass is a no-op on windows-386 because we don't currently build for it.
// It ignores both arguments and always returns nil, so the caller receives
// no closer to hold on to.
func installInterfaceBypass(_ *slog.Logger, _ uint64) closer {
return nil
}
+49 -5
View File
@@ -15,6 +15,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/udp"
) )
type TestTun struct { type TestTun struct {
@@ -54,9 +55,12 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTu
return nil, fmt.Errorf("newTunFromFd not supported") return nil, fmt.Errorf("newTunFromFd not supported")
} }
// Send will place a byte array onto the receive queue for nebula to consume // Send will place a byte array onto the receive queue for nebula to consume.
// These are unencrypted ip layer frames destined for another nebula node. // These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get // packets should exit the udp side, capture them with udpConn.Get.
//
// Send copies the input via the freelist, so the caller is free to mutate
// or reuse it after the call returns.
func (t *TestTun) Send(packet []byte) { func (t *TestTun) Send(packet []byte) {
if t.closed.Load() { if t.closed.Load() {
return return
@@ -65,7 +69,9 @@ func (t *TestTun) Send(packet []byte) {
if t.l.Enabled(context.Background(), slog.LevelDebug) { if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Tun receiving injected packet", "dataLen", len(packet)) t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
} }
t.rxPackets <- packet buf := acquireTunBuf(len(packet))
copy(buf, packet)
t.rxPackets <- buf
} }
// Get will pull an unencrypted ip layer frame from the transmit queue // Get will pull an unencrypted ip layer frame from the transmit queue
@@ -110,12 +116,44 @@ func (t *TestTun) Write(b []byte) (n int, err error) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
packet := make([]byte, len(b), len(b)) packet := acquireTunBuf(len(b))
copy(packet, b) copy(packet, b)
t.TxPackets <- packet t.TxPackets <- packet
return len(b), nil return len(b), nil
} }
// ReleaseTunBuf hands a slice received from TxPackets back to the harness
// freelist; the caller must not touch the bytes afterwards. A nil slice is
// ignored, and when the freelist is already full the buffer is simply left
// for the GC.
// Channel-backed instead of sync.Pool because putting a []byte in a sync.Pool escapes the slice header to heap.
func ReleaseTunBuf(b []byte) {
	if b != nil {
		select {
		case tunBufFreelist <- b:
			// Recycled.
		default:
			// Freelist full; drop the buffer for the GC.
		}
	}
}
// tunBufFreelist retains the backing arrays handed out by acquireTunBuf (used by TestTun.Write and TestTun.Send) so
// steady-state allocation drops to zero once the freelist has saturated for the current MTU.
var tunBufFreelist = make(chan []byte, 64)
// acquireTunBuf returns a slice of length n. It takes a buffer from
// tunBufFreelist when one is available, otherwise it allocates one with
// capacity udp.MTU. If that buffer is too small for n, a fresh slice of
// exactly n bytes is allocated instead.
func acquireTunBuf(n int) []byte {
	var buf []byte
	select {
	case buf = <-tunBufFreelist:
	default:
		buf = make([]byte, 0, udp.MTU)
	}

	if n <= cap(buf) {
		return buf[:n]
	}
	return make([]byte, n)
}
func (t *TestTun) Close() error { func (t *TestTun) Close() error {
if t.closed.CompareAndSwap(false, true) { if t.closed.CompareAndSwap(false, true) {
close(t.rxPackets) close(t.rxPackets)
@@ -129,8 +167,14 @@ func (t *TestTun) Read(b []byte) (int, error) {
if !ok { if !ok {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
n := len(p)
copy(b, p) copy(b, p)
return len(p), nil // Send always pushes a freelist-acquired slice, return it once we've copied the bytes into the caller's buffer.
select {
case tunBufFreelist <- p:
default:
}
return n, nil
} }
func (t *TestTun) SupportsMultiqueue() bool { func (t *TestTun) SupportsMultiqueue() bool {
+44 -6
View File
@@ -25,6 +25,10 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
) )
// closer is the handle type stored in winTun.wdfBypass. Activate fills it with
// the value returned by installInterfaceBypass, and winTun.Close calls Close on
// it before closing the tun device. Unlike io.Closer, Close returns nothing.
type closer interface {
Close()
}
const tunGUIDLabel = "Fixed Nebula Windows GUID v1" const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct { type winTun struct {
@@ -33,6 +37,11 @@ type winTun struct {
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
guid windows.GUID
networkCategory networkCategory
setCategory bool
bypassWDF bool
wdfBypass closer
l *slog.Logger l *slog.Logger
tun *wintun.NativeTun tun *wintun.NativeTun
@@ -54,10 +63,19 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
return nil, fmt.Errorf("generate GUID failed: %w", err) return nil, fmt.Errorf("generate GUID failed: %w", err)
} }
cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private"))
if err != nil {
return nil, err
}
t := &winTun{ t := &winTun{
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
guid: *guid,
networkCategory: cat,
setCategory: setCat,
bypassWDF: c.GetBool("tun.windows_bypass_wdf", true),
l: l, l: l,
} }
@@ -142,6 +160,17 @@ func (t *winTun) Activate() error {
return err return err
} }
if t.setCategory {
// The wintun adapter takes a moment to register with the Network List
// Manager, so we apply the category in the background and retry until
// it shows up.
go applyNetworkCategory(t.l, t.guid, t.networkCategory)
}
if t.bypassWDF {
t.wdfBypass = installInterfaceBypass(t.l, uint64(t.tun.LUID()))
}
return nil return nil
} }
@@ -156,11 +185,8 @@ func (t *winTun) addRoutes(logErrors bool) error {
continue continue
} }
// Add our unsafe route // Add our unsafe route as an on-link route to the nebula tun device.
// Windows does not support multipath routes natively, so we install only a single route. err := luid.AddRoute(r.Cidr, unspecifiedNextHop(r.Cidr), uint32(r.Metric))
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
if err != nil { if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
@@ -206,7 +232,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
} }
// See comment on luid.AddRoute // See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) err := luid.DeleteRoute(r.Cidr, unspecifiedNextHop(r.Cidr))
if err != nil { if err != nil {
t.l.Error("Failed to remove route", "error", err, "route", r) t.l.Error("Failed to remove route", "error", err, "route", r)
} else { } else {
@@ -258,9 +284,21 @@ func (t *winTun) Close() error {
_ = luid.FlushDNS(windows.AF_INET) _ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6) _ = luid.FlushDNS(windows.AF_INET6)
if t.wdfBypass != nil {
t.wdfBypass.Close()
t.wdfBypass = nil
}
return t.tun.Close() return t.tun.Close()
} }
// unspecifiedNextHop picks the all-zeros next hop matching the address family
// of p: 0.0.0.0 when p's address is IPv4, and :: for everything else. It is
// used as the next hop when adding and deleting on-link routes for the tun
// device.
func unspecifiedNextHop(p netip.Prefix) netip.Addr {
	switch {
	case p.Addr().Is4():
		return netip.IPv4Unspecified()
	default:
		return netip.IPv6Unspecified()
	}
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) { func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit // GUID is 128 bit
hash := crypto.MD5.New() hash := crypto.MD5.New()
+3 -5
View File
@@ -99,12 +99,10 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
var currentState *CertState var currentState *CertState
if initial { if initial {
cipher = c.GetString("cipher", "aes") cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch cipher { switch cipher {
case "aes": case "aes", "chachapoly":
noiseEndianness = binary.BigEndian // Each post-handshake CipherState in noiseutil hardcodes its own
case "chachapoly": // nonce endianness now, so there's nothing to set up here.
noiseEndianness = binary.LittleEndian
default: default:
return util.NewContextualError( return util.NewContextualError(
"unknown cipher", "unknown cipher",
+158 -33
View File
@@ -1,24 +1,70 @@
package nebula package nebula
import ( import (
"context"
"log/slog" "log/slog"
"net/netip"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
) )
// holepunchQueueSize buffers the channel that pending holepunchJobs land on after their delay timer fires.
const holepunchQueueSize = 64
// holepunchJob is one scheduled item delivered to the worker goroutine.
// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context.
// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback").
// - neither valid -> the worker does nothing with the job.
type holepunchJob struct {
// target is the UDP address to punch; when valid it takes priority over vpnAddr.
target netip.AddrPort
// vpnAddr is the peer's vpn address: log context for a punch, or the destination of a punchback.
vpnAddr netip.Addr
}
// lighthouseChecker is the slice of LightHouse that Punchy actually needs.
// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse
// already holds a *Punchy, and the bidirectional pointer reference is awkward
// even within the same package). Tests can also substitute a fake.
type lighthouseChecker interface {
// IsAnyLighthouseAddr is consulted with a hostinfo's vpn addrs; when it
// returns true, SendPunch and SendPunchToAll skip sending keepalive punches.
IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool
}
type Punchy struct { type Punchy struct {
punch atomic.Bool punch atomic.Bool
respond atomic.Bool respond atomic.Bool
delay atomic.Int64 delay atomic.Int64
respondDelay atomic.Int64 respondDelay atomic.Int64
punchEverything atomic.Bool punchEverything atomic.Bool
sched *Scheduler[holepunchJob]
punchConn udp.Conn
metricHolepunchTx metrics.Counter
metricPunchyTx metrics.Counter
ctx context.Context
ifce EncWriter
hm *HostMap
lh lighthouseChecker
l *slog.Logger l *slog.Logger
} }
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy {
p := &Punchy{l: l} p := &Punchy{
l: l,
punchConn: punchConn,
sched: NewScheduler[holepunchJob](holepunchQueueSize),
metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
}
if c.GetBool("stats.lighthouse_metrics", false) {
p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
p.metricHolepunchTx = metrics.NilCounter{}
}
p.reload(c, true) p.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
@@ -29,7 +75,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
} }
func (p *Punchy) reload(c *config.C, initial bool) { func (p *Punchy) reload(c *config.C, initial bool) {
if initial { if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
var yes bool var yes bool
if c.IsSet("punchy.punch") { if c.IsSet("punchy.punch") {
yes = c.GetBool("punchy.punch", false) yes = c.GetBool("punchy.punch", false)
@@ -38,16 +84,15 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punchy", false) yes = c.GetBool("punchy", false)
} }
p.punch.Store(yes) old := p.punch.Swap(yes)
if yes { switch {
case initial && yes:
p.l.Info("punchy enabled") p.l.Info("punchy enabled")
} else { case initial:
p.l.Info("punchy disabled") p.l.Info("punchy disabled")
case old != yes:
p.l.Info("punchy.punch changed", "punch", yes)
} }
} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
} }
if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") { if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") {
@@ -59,52 +104,132 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punch_back", false) yes = c.GetBool("punch_back", false)
} }
p.respond.Store(yes) old := p.respond.Swap(yes)
if !initial && old != yes {
if !initial { p.l.Info("punchy.respond changed", "respond", yes)
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
} }
} }
//NOTE: this will not apply to any in progress operations, only the next one //NOTE: this will not apply to any in progress operations, only the next one
if initial || c.HasChanged("punchy.delay") { if initial || c.HasChanged("punchy.delay") {
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) newDelay := int64(c.GetDuration("punchy.delay", time.Second))
if !initial { old := p.delay.Swap(newDelay)
p.l.Info("punchy.delay changed", "delay", p.GetDelay()) if !initial && old != newDelay {
p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay))
} }
} }
if initial || c.HasChanged("punchy.target_all_remotes") { if initial || c.HasChanged("punchy.target_all_remotes") {
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) yes := c.GetBool("punchy.target_all_remotes", false)
if !initial { old := p.punchEverything.Swap(yes)
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) if !initial && old != yes {
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes)
} }
} }
if initial || c.HasChanged("punchy.respond_delay") { if initial || c.HasChanged("punchy.respond_delay") {
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second))
if !initial { old := p.respondDelay.Swap(newDelay)
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) if !initial && old != newDelay {
p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay))
} }
} }
} }
func (p *Punchy) GetPunch() bool { // Schedule queues a punch packet to target, to be sent after the configured delay.
return p.punch.Load() // vpnAddr is the peer's vpn addr, used for log context when the packet actually fires.
// No-op if target is not a valid AddrPort or if Start has not yet been called. Safe to call from any goroutine.
func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) {
if !target.IsValid() || p.ctx == nil {
return
}
p.scheduleJob(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load()))
} }
func (p *Punchy) GetRespond() bool { // ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay,
return p.respond.Load() // gated on punchy.respond. No-op when respond is disabled or before Start has been called.
func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) {
if !p.respond.Load() || p.ctx == nil {
return
}
p.scheduleJob(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load()))
} }
func (p *Punchy) GetDelay() time.Duration { // scheduleJob delegates to the pooled Scheduler.
return (time.Duration)(p.delay.Load()) // The callback observes p.ctx so a job that becomes due after Stop is dropped instead of queued.
func (p *Punchy) scheduleJob(job holepunchJob, delay time.Duration) {
p.sched.Schedule(p.ctx, job, delay)
} }
func (p *Punchy) GetRespondDelay() time.Duration { // SendPunch sends an immediate keepalive punch for an idle hostinfo.
return (time.Duration)(p.respondDelay.Load()) // The configured punchy.target_all_remotes mode picks the targets. Gated on punchy.punch and the lighthouse-skip rule
// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm).
func (p *Punchy) SendPunch(hostinfo *HostInfo) {
if !p.punch.Load() {
return
}
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
return
} }
func (p *Punchy) GetTargetEverything() bool { if p.punchEverything.Load() {
return p.punchEverything.Load() p.sendPunchToAllRemotes(hostinfo)
} else if hostinfo.remote.IsValid() {
p.metricPunchyTx.Inc(1)
p.punchConn.WriteTo([]byte{1}, hostinfo.remote)
}
}
// SendPunchToAll fans a punch out to every known remote of hostinfo. It is a
// no-op unless both punchy.target_all_remotes and punchy.punch are enabled, and
// it also skips hosts that the lighthouse checker reports as lighthouses.
//
// The connection manager calls this during outbound-only traffic: that traffic
// already keeps the primary remote's NAT state warm, but non-primary remotes
// need their own refresh, so all of them are punched (the redundant punch to
// the primary is harmless).
func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) {
	// Order matters only for short-circuiting: config flags first, then the lighthouse lookup.
	enabled := p.punchEverything.Load() && p.punch.Load()
	if !enabled || p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
		return
	}
	p.sendPunchToAllRemotes(hostinfo)
}
// sendPunchToAllRemotes writes a one-byte punch packet to every address that
// hostinfo.remotes.ForEach yields for the hostmap's preferred ranges, bumping
// the punchy tx counter once per address. The preferred flag passed to the
// callback is not used, and whatever WriteTo returns is discarded.
func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) {
	punchOne := func(addr netip.AddrPort, _ bool) {
		p.metricPunchyTx.Inc(1)
		p.punchConn.WriteTo([]byte{1}, addr)
	}
	hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), punchOne)
}
// Start wires the runtime dependencies and spawns the scheduler worker.
func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) {
p.ctx = ctx
p.ifce = ifce
p.hm = hm
p.lh = lh
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
empty := []byte{0}
go p.sched.Run(ctx, func(job holepunchJob) {
switch {
case job.target.IsValid():
if p.l.Enabled(context.Background(), slog.LevelDebug) {
p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr)
}
p.metricHolepunchTx.Inc(1)
p.punchConn.WriteTo(empty, job.target)
case job.vpnAddr.IsValid():
// A nebula test packet to the host trying to contact us.
// In the case of a double nat or other difficult scenario, this may help establish a tunnel.
if p.l.Enabled(context.Background(), slog.LevelDebug) {
p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr)
}
p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out)
}
})
} }
+40 -41
View File
@@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
// Test defaults // Test defaults
p := NewPunchyFromConfig(test.NewLogger(), c) p := NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.False(t, p.GetPunch()) assert.False(t, p.punch.Load())
assert.False(t, p.GetRespond()) assert.False(t, p.respond.Load())
assert.Equal(t, time.Second, p.GetDelay()) assert.Equal(t, time.Second, time.Duration(p.delay.Load()))
assert.Equal(t, 5*time.Second, p.GetRespondDelay()) assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load()))
// punchy deprecation // punchy deprecation
c.Settings["punchy"] = true c.Settings["punchy"] = true
p = NewPunchyFromConfig(test.NewLogger(), c) p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.GetPunch()) assert.True(t, p.punch.Load())
// punchy.punch // punchy.punch
c.Settings["punchy"] = map[string]any{"punch": true} c.Settings["punchy"] = map[string]any{"punch": true}
p = NewPunchyFromConfig(test.NewLogger(), c) p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.GetPunch()) assert.True(t, p.punch.Load())
// punch_back deprecation // punch_back deprecation
c.Settings["punch_back"] = true c.Settings["punch_back"] = true
p = NewPunchyFromConfig(test.NewLogger(), c) p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.GetRespond()) assert.True(t, p.respond.Load())
// punchy.respond // punchy.respond
c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punch_back"] = false c.Settings["punch_back"] = false
p = NewPunchyFromConfig(test.NewLogger(), c) p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.GetRespond()) assert.True(t, p.respond.Load())
// punchy.delay // punchy.delay
c.Settings["punchy"] = map[string]any{"delay": "1m"} c.Settings["punchy"] = map[string]any{"delay": "1m"}
p = NewPunchyFromConfig(test.NewLogger(), c) p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.Equal(t, time.Minute, p.GetDelay()) assert.Equal(t, time.Minute, time.Duration(p.delay.Load()))
// punchy.respond_delay // punchy.respond_delay
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
p = NewPunchyFromConfig(test.NewLogger(), c) p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.Equal(t, time.Minute, p.GetRespondDelay()) assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load()))
} }
func TestPunchy_reload(t *testing.T) { func TestPunchy_reload(t *testing.T) {
@@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) {
delay, _ := time.ParseDuration("1m") delay, _ := time.ParseDuration("1m")
require.NoError(t, c.LoadString(` require.NoError(t, c.LoadString(`
punchy: punchy:
punch: false
delay: 1m delay: 1m
respond: false respond: false
`)) `))
p := NewPunchyFromConfig(test.NewLogger(), c) p := NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.Equal(t, delay, p.GetDelay()) assert.False(t, p.punch.Load())
assert.False(t, p.GetRespond()) assert.Equal(t, delay, time.Duration(p.delay.Load()))
assert.False(t, p.respond.Load())
newDelay, _ := time.ParseDuration("10m") newDelay, _ := time.ParseDuration("10m")
require.NoError(t, c.ReloadConfigString(` require.NoError(t, c.ReloadConfigString(`
punchy: punchy:
punch: true
delay: 10m delay: 10m
respond: true respond: true
`)) `))
p.reload(c, false) p.reload(c, false)
assert.Equal(t, newDelay, p.GetDelay()) assert.True(t, p.punch.Load())
assert.True(t, p.GetRespond()) assert.Equal(t, newDelay, time.Duration(p.delay.Load()))
assert.True(t, p.respond.Load())
} }
// The tests below pin the shape of each log line Punchy produces so changes // The tests below pin the shape of each log line Punchy produces so changes
// cannot silently break whatever operators are grepping for. The assertions // cannot silently break whatever operators are grepping for. The assertions
// are on the structured message + attrs (e.g. "punchy.respond changed" with // are on the structured message + attrs (e.g. "punchy.respond changed" with
// a respond=true field) rather than a formatted string. // a respond=true field) rather than a formatted string. Tests filter by
// // message rather than asserting total entry counts so unrelated info lines
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is // are tolerated without being locked into the format.
// not supported" warning whenever any key under punchy changes, because of
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
// punchy form. The tests filter by message rather than asserting total
// entry counts so that warning is tolerated without being locked into
// the format.
type capturedEntry struct { type capturedEntry struct {
Level slog.Level Level slog.Level
@@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
c := config.NewC(test.NewLogger()) c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: true}`)) require.NoError(t, c.LoadString(`punchy: {punch: true}`))
NewPunchyFromConfig(l, c) NewPunchyFromConfig(l, c, nil)
entry := findEntry(t, hook.entries, "punchy enabled") entry := findEntry(t, hook.entries, "punchy enabled")
assert.Equal(t, slog.LevelInfo, entry.Level) assert.Equal(t, slog.LevelInfo, entry.Level)
@@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
c := config.NewC(test.NewLogger()) c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: false}`)) require.NoError(t, c.LoadString(`punchy: {punch: false}`))
NewPunchyFromConfig(l, c) NewPunchyFromConfig(l, c, nil)
entry := findEntry(t, hook.entries, "punchy disabled") entry := findEntry(t, hook.entries, "punchy disabled")
assert.Equal(t, slog.LevelInfo, entry.Level) assert.Equal(t, slog.LevelInfo, entry.Level)
assert.Empty(t, entry.Attrs) assert.Empty(t, entry.Attrs)
} }
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { func TestPunchy_LogFormat_ReloadPunch(t *testing.T) {
l, hook := newCapturingPunchyLogger(t) l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger()) c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: false}`)) require.NoError(t, c.LoadString(`punchy: {punch: false}`))
NewPunchyFromConfig(l, c) NewPunchyFromConfig(l, c, nil)
hook.entries = nil hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") entry := findEntry(t, hook.entries, "punchy.punch changed")
assert.Equal(t, slog.LevelWarn, entry.Level) assert.Equal(t, slog.LevelInfo, entry.Level)
assert.Empty(t, entry.Attrs) assert.Equal(t, map[string]any{"punch": true}, entry.Attrs)
} }
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
l, hook := newCapturingPunchyLogger(t) l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger()) c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {respond: false}`)) require.NoError(t, c.LoadString(`punchy: {respond: false}`))
NewPunchyFromConfig(l, c) NewPunchyFromConfig(l, c, nil)
hook.entries = nil hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
@@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
l, hook := newCapturingPunchyLogger(t) l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger()) c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
NewPunchyFromConfig(l, c) NewPunchyFromConfig(l, c, nil)
hook.entries = nil hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
@@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
l, hook := newCapturingPunchyLogger(t) l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger()) c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
NewPunchyFromConfig(l, c) NewPunchyFromConfig(l, c, nil)
hook.entries = nil hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
@@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
l, hook := newCapturingPunchyLogger(t) l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger()) c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
NewPunchyFromConfig(l, c) NewPunchyFromConfig(l, c, nil)
hook.entries = nil hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`)) require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
+175 -4
View File
@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/netip" "net/netip"
"slices"
"sync/atomic" "sync/atomic"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@@ -18,6 +19,7 @@ type relayManager struct {
l *slog.Logger l *slog.Logger
hostmap *HostMap hostmap *HostMap
amRelay atomic.Bool amRelay atomic.Bool
useRelays atomic.Bool
} }
func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager { func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
@@ -36,8 +38,10 @@ func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *c
} }
func (rm *relayManager) reload(c *config.C, initial bool) error { func (rm *relayManager) reload(c *config.C, initial bool) error {
if initial || c.HasChanged("relay.am_relay") { if initial || c.HasChanged("relay.am_relay") || c.HasChanged("relay.use_relays") {
rm.setAmRelay(c.GetBool("relay.am_relay", false)) amRelay := c.GetBool("relay.am_relay", false)
rm.amRelay.Store(amRelay)
rm.useRelays.Store(c.GetBool("relay.use_relays", true) && !amRelay)
} }
return nil return nil
} }
@@ -46,8 +50,175 @@ func (rm *relayManager) GetAmRelay() bool {
return rm.amRelay.Load() return rm.amRelay.Load()
} }
func (rm *relayManager) setAmRelay(v bool) { func (rm *relayManager) GetUseRelays() bool {
rm.amRelay.Store(v) return rm.useRelays.Load()
}
// StartRelays drives the relay-establishment side of an outbound handshake attempt.
// For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits
// one that may have been lost, or, once the relay is Established, forwards the in-progress
// stage 0 handshake packet for vpnIp through it.
func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hh *HandshakeHostInfo, stage0 []byte) {
hostinfo := hh.hostinfo
if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 {
hh.lastRelays = nil
return
}
relays := hostinfo.remotes.relays
listLevel := slog.LevelDebug
prior := hh.lastRelays
if !slices.Equal(relays, prior) {
listLevel = slog.LevelInfo
hh.lastRelays = slices.Clone(relays)
}
hl := hostinfo.logger(rm.l)
hl.Log(context.Background(), listLevel, "Attempt to relay through hosts", "relays", relays)
// Send a RelayRequest to all known Relay IP's
for _, relay := range relays {
// Don't relay through the host I'm trying to connect to
if relay == vpnIp {
continue
}
// Don't relay to myself
if f.myVpnAddrsTable.Contains(relay) {
continue
}
// Each relay's per-attempt log fires at Info on the first time we hit it and Debug after that.
level := slog.LevelInfo
if slices.Contains(prior, relay) {
level = slog.LevelDebug
}
relayHostInfo := rm.hostmap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hl.Log(context.Background(), level, "Establish tunnel to relay target", "relay", relay.String())
f.Handshake(relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through
existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
if !ok {
// No relays exist or requested yet.
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hl.Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
}
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !f.myVpnAddrs[0].Is4() {
hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hl.Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hl.Error("Failed to marshal Control message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.Log(context.Background(), level, "send CreateRelayRequest",
"relayFrom", f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", idx,
"relay", relay,
)
}
}
continue
}
switch existingRelay.State {
case Established:
hl.Log(context.Background(), level, "Send handshake via relay", "relay", relay.String())
f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hl.Log(context.Background(), level, "Re-send CreateRelay request", "relay", relay.String())
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !f.myVpnAddrs[0].Is4() {
hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hl.Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hl.Error("Failed to marshal Control message to create relay", "error", err)
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.Log(context.Background(), level, "send CreateRelayRequest",
"relayFrom", f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", existingRelay.LocalIndex,
"relay", relay,
)
}
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hl.Error("Relay unexpected state",
"vpnIp", vpnIp,
"state", existingRelay.State,
"relay", relay,
)
}
}
} }
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
+97
View File
@@ -0,0 +1,97 @@
package nebula
import (
"bytes"
"log/slog"
"net/netip"
"testing"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
// TestStartRelaysLogDedupe checks which level StartRelays uses for its "Attempt to relay through hosts"
// log line: Info the first time a relay set is tried, Debug when the identical set is retried, and Info
// again once the set changes. With relays disabled the line must not be emitted at all and lastRelays
// must be cleared.
func TestStartRelaysLogDedupe(t *testing.T) {
	const attemptLog = `msg="Attempt to relay through hosts"`

	target := netip.MustParseAddr("100.64.99.4")
	extraRelay := netip.MustParseAddr("100.64.99.5")

	// makeHandshake builds a handshake whose only relay is the target's own vpnIp, so the loop body
	// skips it without touching any sender-side state. That keeps the test focused on the level
	// chosen for the top-level "Attempt to relay through hosts" log.
	makeHandshake := func() *HandshakeHostInfo {
		hi := &HostInfo{
			vpnAddrs:     []netip.Addr{target},
			localIndexId: 1,
			remotes:      NewRemoteList([]netip.Addr{target}, nil),
		}
		hi.remotes.relays = []netip.Addr{target}
		return &HandshakeHostInfo{hostinfo: hi}
	}

	// Any extra relay address introduced mid-test is parked in myVpnAddrsTable so the loop body
	// always skips before touching f.Handshake (which would need a real handshakeManager).
	ownAddrs := new(bart.Lite)
	ownAddrs.Insert(netip.PrefixFrom(extraRelay, extraRelay.BitLen()))
	f := &Interface{myVpnAddrsTable: ownAddrs}

	// makeManager returns a relayManager with relays enabled that logs at Debug into sink.
	makeManager := func(sink *bytes.Buffer) *relayManager {
		logger := test.NewLoggerWithOutputAndLevel(sink, slog.LevelDebug)
		mgr := &relayManager{l: logger, hostmap: newHostMap(logger)}
		mgr.useRelays.Store(true)
		return mgr
	}

	t.Run("first attempt logs at Info", func(t *testing.T) {
		var logs bytes.Buffer
		mgr := makeManager(&logs)
		hh := makeHandshake()

		mgr.StartRelays(f, target, hh, nil)

		assert.Equal(t, []netip.Addr{target}, hh.lastRelays, "lastRelays should record the relay set we just attempted")
		assert.Contains(t, logs.String(), "level=INFO "+attemptLog, "expected Info level on first attempt")
	})

	t.Run("repeat attempt with same relays drops to Debug", func(t *testing.T) {
		var logs bytes.Buffer
		mgr := makeManager(&logs)
		hh := makeHandshake()

		mgr.StartRelays(f, target, hh, nil)
		// Snapshot what the first attempt recorded, then look only at the second attempt's output.
		recorded := append([]netip.Addr(nil), hh.lastRelays...)
		logs.Reset()
		mgr.StartRelays(f, target, hh, nil)

		assert.Equal(t, recorded, hh.lastRelays)
		assert.Contains(t, logs.String(), "level=DEBUG "+attemptLog, "expected Debug level on identical retry")
		assert.NotContains(t, logs.String(), "level=INFO "+attemptLog, "Info should not fire on identical retry")
	})

	t.Run("changed relay list bumps back to Info", func(t *testing.T) {
		var logs bytes.Buffer
		mgr := makeManager(&logs)
		hh := makeHandshake()

		mgr.StartRelays(f, target, hh, nil)
		logs.Reset()
		// The lighthouse handed us a new set this round.
		hh.hostinfo.remotes.relays = []netip.Addr{target, extraRelay}
		mgr.StartRelays(f, target, hh, nil)

		assert.Equal(t, []netip.Addr{target, extraRelay}, hh.lastRelays)
		assert.Contains(t, logs.String(), "level=INFO "+attemptLog, "expected Info when the relay list changes")
	})

	t.Run("disabled relays clears lastRelays and emits no Attempt log", func(t *testing.T) {
		var logs bytes.Buffer
		mgr := makeManager(&logs)
		mgr.useRelays.Store(false)
		hh := makeHandshake()
		hh.lastRelays = []netip.Addr{target}

		mgr.StartRelays(f, target, hh, nil)

		assert.Nil(t, hh.lastRelays, "with relays disabled lastRelays should be cleared")
		assert.NotContains(t, logs.String(), attemptLog, "should not log when we shortcut out")
	})
}
+25
View File
@@ -239,6 +239,31 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
r.hr = hr r.hr = hr
} }
// ResetForOwner drops the addresses that ownerVpnAddr reported for this remote and marks the addrs
// list dirty so it is recomputed on the next rebuild.
// The reported v4/v6 slices are truncated to length zero rather than replaced, and entries cached for
// other owners are left alone. Any pending hostname resolution is canceled. An owner without a cache
// entry still causes the cancel and the dirty flag.
func (r *RemoteList) ResetForOwner(ownerVpnAddr netip.Addr) {
	r.Lock()
	defer r.Unlock()

	r.hr.Cancel()

	if entry, found := r.cache[ownerVpnAddr]; found {
		if v4 := entry.v4; v4 != nil {
			v4.reported = v4.reported[:0]
		}
		if v6 := entry.v6; v6 != nil {
			v6.reported = v6.reported[:0]
		}
	}

	r.shouldRebuild = true
}
// ClearHostnameResults cancels the in-flight DNS resolver goroutine (if any) and drops the resolved IP cache.
// It takes the write lock itself, so callers must not already hold it.
func (r *RemoteList) ClearHostnameResults() {
	r.Lock()
	defer r.Unlock()
	// Installing nil replaces whatever hostnamesResults was attached before.
	r.unlockedSetHostnamesResults(nil)
	// Flag the address list as stale so it is recomputed without the dropped IPs.
	r.shouldRebuild = true
}
// Len locks and reports the size of the deduplicated address list // Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges // The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
+89
View File
@@ -6,8 +6,22 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// trackedHostnameResults returns a *hostnamesResults wired to the given cancel function, with its ips
// map pre-filled from addrs (each parsed as an "ip:port" string). Tests use it to assert that
// cancellation happened and that previously-resolved IPs survive a cancel, without spinning up a
// real DNS resolver.
func trackedHostnameResults(cancelFn func(), addrs ...string) *hostnamesResults {
	resolved := make(map[netip.AddrPort]struct{}, len(addrs))
	for _, raw := range addrs {
		resolved[netip.MustParseAddrPort(raw)] = struct{}{}
	}

	tracked := &hostnamesResults{cancelFn: cancelFn}
	tracked.ips.Store(&resolved)
	return tracked
}
func TestRemoteList_Rebuild(t *testing.T) { func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
@@ -112,6 +126,81 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
} }
func TestRemoteList_ResetForOwner(t *testing.T) {
ourselves := netip.MustParseAddr("10.0.0.1")
otherOwner := netip.MustParseAddr("10.0.0.2")
vpnAddr := netip.MustParseAddr("10.0.0.99")
rl := NewRemoteList([]netip.Addr{vpnAddr}, nil)
rl.unlockedSetV4(ourselves, vpnAddr,
[]*V4AddrPort{newIp4AndPortFromString("1.1.1.1:4242")},
func(netip.Addr, *V4AddrPort) bool { return true },
)
rl.unlockedSetV6(ourselves, vpnAddr,
[]*V6AddrPort{newIp6AndPortFromString("[1::1]:4242")},
func(netip.Addr, *V6AddrPort) bool { return true },
)
rl.unlockedSetV4(otherOwner, vpnAddr,
[]*V4AddrPort{newIp4AndPortFromString("2.2.2.2:4242")},
func(netip.Addr, *V4AddrPort) bool { return true },
)
canceled := 0
hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")
rl.Lock()
rl.unlockedSetHostnamesResults(hr)
rl.Unlock()
rl.ResetForOwner(ourselves)
rl.RLock()
defer rl.RUnlock()
assert.Empty(t, rl.cache[ourselves].v4.reported, "our v4 reported should be cleared")
assert.Empty(t, rl.cache[ourselves].v6.reported, "our v6 reported should be cleared")
assert.Len(t, rl.cache[otherOwner].v4.reported, 1, "other owner's contribution must be preserved")
assert.Equal(t, "2.2.2.2:4242", protoV4AddrPortToNetAddrPort(rl.cache[otherOwner].v4.reported[0]).String())
assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled")
assert.Same(t, hr, rl.hr, "hostnamesResults must be preserved so DNS-resolved IPs keep feeding addrs until replaced")
assert.NotEmpty(t, rl.hr.GetAddrs(), "previously-resolved IPs should still be readable after cancel")
assert.True(t, rl.shouldRebuild, "shouldRebuild must be set so the next Rebuild recomputes addrs")
}
// TestRemoteList_ResetForOwner_NoEntry covers an owner that has no cache entry: the call must not
// panic, shouldRebuild is still set, and any existing hostnamesResults is canceled.
func TestRemoteList_ResetForOwner_NoEntry(t *testing.T) {
	rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil)

	cancelCalls := 0
	onCancel := func() { cancelCalls++ }
	rl.Lock()
	rl.unlockedSetHostnamesResults(trackedHostnameResults(onCancel, "3.3.3.3:4242"))
	rl.Unlock()

	unknownOwner := netip.MustParseAddr("10.0.0.1")
	rl.ResetForOwner(unknownOwner)

	rl.RLock()
	defer rl.RUnlock()
	assert.Equal(t, 1, cancelCalls)
	assert.True(t, rl.shouldRebuild)
}
// TestRemoteList_ClearHostnameResults verifies that ClearHostnameResults cancels the attached
// hostnamesResults, detaches it from the list, and marks the list for rebuild.
func TestRemoteList_ClearHostnameResults(t *testing.T) {
	target := netip.MustParseAddr("10.0.0.99")
	rl := NewRemoteList([]netip.Addr{target}, nil)

	cancelCalls := 0
	hr := trackedHostnameResults(func() { cancelCalls++ }, "3.3.3.3:4242")
	rl.Lock()
	rl.unlockedSetHostnamesResults(hr)
	rl.Unlock()
	// Precondition: the fixture really carries a resolved address before we clear it.
	require.NotEmpty(t, hr.GetAddrs(), "hostnamesResults should have its fastrack IPs populated")

	rl.ClearHostnameResults()

	rl.RLock()
	defer rl.RUnlock()
	assert.Equal(t, 1, cancelCalls, "DNS resolution goroutine should be canceled")
	assert.Nil(t, rl.hr, "hostnamesResults should be dropped")
	assert.True(t, rl.shouldRebuild)
}
func BenchmarkFullRebuild(b *testing.B) { func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
+84
View File
@@ -0,0 +1,84 @@
package nebula
import (
"context"
"sync"
"time"
)
// Scheduler is an allocation-conscious dispatch primitive for delayed work.
// Pending items are handed to a pooled runtime timer, and ready items land on a worker
// channel for centralized dispatch in fire-time order.
//
// Pick a Scheduler when fire timing matters (exact deadlines, no bucketing) or when the scheduling
// rate is uneven enough that idle CPU matters. Each fire is a runtime-spawned goroutine running the callback before
// delivering to the worker, which is fine at sparse rates but adds up at line rate.
//
// Pick a TimerWheel when scheduling is high-rate and uniform: its O(1) insert, internal item cache,
// and bucket-batched dispatch are cheaper at scale.
// The caller drives the tick loop (Advance/Purge) and pays for fires at bucket boundaries rather than exact deadlines.
//
// A Scheduler must be built with NewScheduler; the zero value has no pool constructor and cannot schedule.
type Scheduler[T any] struct {
	// queue carries fired items to the worker loop in Run.
	queue chan T
	// pool recycles schedItems together with the timer and fire closure each one owns.
	pool sync.Pool
}

// schedItem is one pooled scheduling slot. val and ctx are filled in by Schedule and cleared by fire;
// s, timer and fire are set once when the pool creates the item and stay fixed afterwards.
type schedItem[T any] struct {
	val   T
	ctx   context.Context
	s     *Scheduler[T]
	timer *time.Timer
	fire  func()
}

// NewScheduler builds a Scheduler whose worker channel is sized to queueSize.
// The buffer absorbs bursts of timers firing close together without
// blocking the runtime's callback goroutines on the worker.
func NewScheduler[T any](queueSize int) *Scheduler[T] {
	// parked is the delay a freshly created timer is armed with before it is stopped again;
	// it only has to be long enough that the timer cannot fire in between.
	const parked = time.Duration(1<<63 - 1)

	s := &Scheduler[T]{
		queue: make(chan T, queueSize),
	}
	s.pool.New = func() any {
		si := &schedItem[T]{s: s}
		// fire is allocated exactly once per pool-resident item.
		// The closure captures only `si`, which stays stable for the item's lifetime.
		si.fire = func() {
			select {
			case si.s.queue <- si.val:
			case <-si.ctx.Done():
			}
			// Drop references before the item goes back to the pool so it does not pin val or ctx.
			var zero T
			si.val = zero
			si.ctx = nil
			si.s.pool.Put(si)
		}
		// Create the timer here, parked, instead of lazily in Schedule. If Schedule assigned
		// si.timer after time.AfterFunc had already armed it, a short delay could let fire run and
		// return the item to the pool before the assignment landed, and a concurrent Schedule
		// picking the item up would read si.timer unsynchronized. With the timer in place before
		// the item is ever handed out, Schedule only needs Reset.
		si.timer = time.AfterFunc(parked, si.fire)
		si.timer.Stop()
		return si
	}
	return s
}

// Schedule arranges item to be delivered to the worker after delay.
// The runtime's timer heap handles the wait, so the scheduler itself burns no CPU while idle.
// The callback observes ctx: if ctx is cancelled before the item can be handed to the worker queue,
// the item is dropped instead of queued. ctx must not be nil.
func (s *Scheduler[T]) Schedule(ctx context.Context, item T, delay time.Duration) {
	si := s.pool.Get().(*schedItem[T])
	si.val = item
	si.ctx = ctx
	// Every pooled item already owns a stopped or expired timer, so arming it is always a Reset.
	si.timer.Reset(delay)
}

// Run drains the worker queue, calling fn for each item. Returns when ctx is cancelled.
// Tests that want deterministic timing should drive the queue directly rather than going through Schedule + Run.
func (s *Scheduler[T]) Run(ctx context.Context, fn func(T)) {
	for {
		select {
		case <-ctx.Done():
			return
		case item := <-s.queue:
			fn(item)
		}
	}
}
+79
View File
@@ -0,0 +1,79 @@
package nebula
import (
"context"
"testing"
"time"
)
// TestScheduler_PooledReuse schedules a burst of items through a single Scheduler and requires
// every one of them to reach the Run callback within two seconds.
func TestScheduler_PooledReuse(t *testing.T) {
	const total = 100

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	s := NewScheduler[int](16)
	delivered := make(chan int, 256)
	forward := func(item int) { delivered <- item }
	go s.Run(ctx, forward)

	for i := 0; i < total; i++ {
		s.Schedule(ctx, i, time.Millisecond)
	}

	timeout := time.After(2 * time.Second)
	for received := 0; received < total; {
		select {
		case <-delivered:
			received++
		case <-timeout:
			t.Fatalf("only %d/%d items delivered", received, total)
		}
	}
}
// BenchmarkScheduler_Schedule measures allocations per Schedule call.
// Once the Scheduler's sync.Pool has warmed up, steady state should show zero allocs per op.
func BenchmarkScheduler_Schedule(b *testing.B) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The queue is sized to the iteration count.
	s := NewScheduler[int](b.N)
	discard := func(int) {}
	go s.Run(ctx, discard)

	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		s.Schedule(ctx, n, time.Microsecond)
	}
}
// BenchmarkBareAfterFunc is the comparison baseline for BenchmarkScheduler_Schedule:
// what we'd pay per Schedule if Punchy called time.AfterFunc directly without the pooled Scheduler.
// Every call allocates a *time.Timer plus a closure.
func BenchmarkBareAfterFunc(b *testing.B) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	queue := make(chan int, b.N)
	// Stand-in for the Scheduler worker: drain the queue until the benchmark is torn down.
	drain := func() {
		for {
			select {
			case <-ctx.Done():
				return
			case <-queue:
			}
		}
	}
	go drain()

	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		val := n // private copy for the closure below
		time.AfterFunc(time.Microsecond, func() {
			select {
			case queue <- val:
			case <-ctx.Done():
			}
		})
	}
}
+37 -20
View File
@@ -27,21 +27,20 @@ type SSHServer struct {
commands *radix.Tree commands *radix.Tree
listener net.Listener listener net.Listener
// Call the cancel() function to stop all active sessions // ctx parents per-Run contexts. Cancelling it (e.g. via Control.Stop) tears the server down even
// across reloads, since each Run derives a fresh child rather than reusing this one directly.
ctx context.Context ctx context.Context
cancel func()
} }
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen.
func NewSSHServer(l *slog.Logger) (*SSHServer, error) { // The ssh server's context is parented off the supplied ctx so cancelling it
// (e.g. on Control.Stop) tears down active sessions and closes the listener.
ctx, cancel := context.WithCancel(context.Background()) func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) {
s := &SSHServer{ s := &SSHServer{
trustedKeys: make(map[string]map[string]bool), trustedKeys: make(map[string]map[string]bool),
l: l, l: l,
commands: radix.New(), commands: radix.New(),
ctx: ctx, ctx: ctx,
cancel: cancel,
} }
cc := ssh.CertChecker{ cc := ssh.CertChecker{
@@ -151,28 +150,51 @@ func (s *SSHServer) RegisterCommand(c *Command) {
s.commands.Insert(c.Name, c) s.commands.Insert(c.Name, c)
} }
// Run begins listening and accepting connections // Run begins listening and accepting connections. Each invocation derives a fresh per-Run context
// from the constructor-supplied ctx so a Stop+Run sequence (used by config reload) starts clean
// rather than carrying a permanently-cancelled context across runs.
func (s *SSHServer) Run(addr string) error { func (s *SSHServer) Run(addr string) error {
var err error if s.ctx.Err() != nil {
s.listener, err = net.Listen("tcp", addr) return s.ctx.Err()
}
listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return err return err
} }
// s.listener is the public handle Stop uses to interrupt the active run; listener (the local) is what
// this run owns. They start equal but a fast reload may overwrite s.listener with the next run's
// listener before this run's watcher fires, so each run must close its own listener via the local
// reference.
s.listener = listener
runCtx, cancel := context.WithCancel(s.ctx)
defer cancel()
// Close the listener when this run's context is cancelled. That can come from the parent
// (Control.Stop), from Run returning normally (defer cancel above), or transitively when a sibling
// run cancels through Stop closing the listener. net.Listener.Close is idempotent so a duplicate
// close from Stop is benign.
go func() {
<-runCtx.Done()
if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
s.l.Warn("Failed to close the sshd listener", "error", err)
}
}()
s.l.Info("SSH server is listening", "sshListener", addr) s.l.Info("SSH server is listening", "sshListener", addr)
// Run loops until there is an error // Run loops until there is an error
s.run() s.run(runCtx, listener)
s.closeSessions()
s.l.Info("SSH server stopped listening") s.l.Info("SSH server stopped listening")
// We don't return an error because run logs for us // We don't return an error because run logs for us
return nil return nil
} }
func (s *SSHServer) run() { func (s *SSHServer) run(ctx context.Context, listener net.Listener) {
for { for {
c, err := s.listener.Accept() c, err := listener.Accept()
if err != nil { if err != nil {
if !errors.Is(err, net.ErrClosed) { if !errors.Is(err, net.ErrClosed) {
s.l.Warn("Error in listener, shutting down", "error", err) s.l.Warn("Error in listener, shutting down", "error", err)
@@ -184,7 +206,7 @@ func (s *SSHServer) run() {
// Ensure that a bad client doesn't hurt us by checking for the parent context // Ensure that a bad client doesn't hurt us by checking for the parent context
// cancellation before calling NewServerConn, and forcing the socket to close when // cancellation before calling NewServerConn, and forcing the socket to close when
// the context is cancelled. // the context is cancelled.
sessionContext, sessionCancel := context.WithCancel(s.ctx) sessionContext, sessionCancel := context.WithCancel(ctx)
go func() { go func() {
<-sessionContext.Done() <-sessionContext.Done()
c.Close() c.Close()
@@ -227,14 +249,9 @@ func (s *SSHServer) run() {
} }
func (s *SSHServer) Stop() { func (s *SSHServer) Stop() {
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil { if s.listener != nil {
if err := s.listener.Close(); err != nil { if err := s.listener.Close(); err != nil {
s.l.Warn("Failed to close the sshd listener", "error", err) s.l.Warn("Failed to close the sshd listener", "error", err)
} }
} }
} }
func (s *SSHServer) closeSessions() {
s.cancel()
}
+17
View File
@@ -8,6 +8,23 @@ import (
// How many timer objects should be cached // How many timer objects should be cached
const timerCacheMax = 50000 const timerCacheMax = 50000
// TimerWheel is a hashed timing wheel: a fixed slot array indexed by (now + delay) % wheelLen,
// with each slot a singly linked list of items due in that bucket.
// Adds are O(1), Purges return items in arrival-within-slot order, and an internal cache of TimeoutItems
// keeps steady-state inserts allocation-free.
//
// The TimerWheel does not handle concurrency or lifecycle on its own.
// Callers drive Advance/Purge from their own ticker loop, take their own locks (or use LockingTimerWheel),
// and decide whether to keep ticking when the wheel is empty.
//
// Pick a TimerWheel when scheduling is high-rate and uniform: line-rate conntrack inserts,
// per-tunnel traffic checks at fixed intervals. O(1) insert plus the item cache means the hot path doesn't allocate.
// Items added in the same tick are dispatched together when that slot rotates current,
// which amortizes the cost of waking the worker.
//
// Pick a Scheduler when delay precision matters or scheduling is sparse or uneven.
// The wheel rounds requested timeouts up to its tick resolution and clamps anything beyond its wheel duration;
// both are silent in this implementation.
type TimerWheel[T any] struct { type TimerWheel[T any] struct {
// Current tick // Current tick
current int current int
+1 -2
View File
@@ -5,12 +5,11 @@ package udp
import ( import (
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"log/slog"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
+1 -2
View File
@@ -8,12 +8,11 @@ package udp
import ( import (
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"log/slog"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
+57
View File
@@ -0,0 +1,57 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
package udp
import (
"log/slog"
"sync"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/wfp"
)
// wrapWithWDFBypass decorates conn with a bypassConn. Nothing is installed at wrap time: the first
// ReloadConfig on the returned Conn consults listen.windows_bypass_wdf and, if enabled, installs a
// WFP PERMIT filter for the listener's bound UDP port. The WFP session is released when Close runs.
func wrapWithWDFBypass(l *slog.Logger, conn Conn) Conn {
	wrapped := &bypassConn{l: l, Conn: conn}
	return wrapped
}
// bypassConn is a Conn decorator that owns an optional WFP bypass session for the wrapped listener.
type bypassConn struct {
	// Conn is the wrapped listener; its methods are promoted unless overridden on bypassConn.
	Conn
	// l reports the outcome of the bypass installation.
	l *slog.Logger
	// installOnce limits the installation attempt in ReloadConfig to a single run.
	installOnce sync.Once
	// session is non-nil only after wfp.PermitUDPPort succeeded; Close releases it.
	session *wfp.Session
}
// ReloadConfig makes the one-time WFP bypass installation attempt on its first invocation and then
// always forwards the config to the wrapped Conn.
func (b *bypassConn) ReloadConfig(c *config.C) {
	b.installOnce.Do(func() { b.installBypass(c) })
	b.Conn.ReloadConfig(c)
}

// installBypass installs the WFP PERMIT filters for the listener's local UDP port unless
// listen.windows_bypass_wdf (default true) is disabled. Failures are logged as warnings and leave
// b.session unset; they are not returned to the caller.
func (b *bypassConn) installBypass(c *config.C) {
	if enabled := c.GetBool("listen.windows_bypass_wdf", true); !enabled {
		return
	}

	local, err := b.Conn.LocalAddr()
	if err != nil {
		b.l.Warn("Failed to query listener port for WFP bypass", "error", err)
		return
	}
	port := local.Port()

	session, err := wfp.PermitUDPPort(port)
	if err != nil {
		b.l.Warn("Failed to install WFP bypass filters for listener", "error", err)
		return
	}

	b.l.Info("Installed WFP filters bypassing Windows Defender Firewall on UDP listener port",
		"port", port)
	b.session = session
}
// Close releases the WFP session, if one was installed, and then closes the wrapped Conn,
// returning the wrapped Conn's error.
func (b *bypassConn) Close() error {
	if held := b.session; held != nil {
		held.Close()
		b.session = nil
	}
	err := b.Conn.Close()
	return err
}
+11
View File
@@ -0,0 +1,11 @@
//go:build !e2e_testing
// +build !e2e_testing
package udp
import "log/slog"
// wrapWithWDFBypass returns conn unchanged and ignores the logger.
// This is the no-op variant for windows-386, which we don't currently build for, so no WFP bypass
// filters are installed there.
func wrapWithWDFBypass(_ *slog.Logger, conn Conn) Conn {
	return conn
}
+1 -2
View File
@@ -7,12 +7,11 @@ package udp
import ( import (
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"log/slog"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
+50 -13
View File
@@ -21,17 +21,48 @@ type Packet struct {
Data []byte Data []byte
} }
// Copy returns a fresh *Packet (from the freelist) with a duplicate Data buffer.
func (u *Packet) Copy() *Packet { func (u *Packet) Copy() *Packet {
n := &Packet{ n := acquirePacket()
To: u.To, n.To = u.To
From: u.From, n.From = u.From
Data: make([]byte, len(u.Data)), if cap(n.Data) < len(u.Data) {
n.Data = make([]byte, len(u.Data))
} else {
n.Data = n.Data[:len(u.Data)]
} }
copy(n.Data, u.Data) copy(n.Data, u.Data)
return n return n
} }
// Release hands p back to the harness packet freelist; a nil p is ignored.
// Callers that pull a *Packet from Get / TxPackets must Release when done.
// Data is truncated to length zero but keeps its backing array for reuse. To and From are left as
// they are.
// The freelist is channel-backed instead of a sync.Pool because sync.Pool's per-P caches drain badly
// under cross-goroutine Get/Put, and putting a []byte in a Pool escapes the slice header to heap.
func (p *Packet) Release() {
	if p != nil {
		p.Data = p.Data[:0]
		select {
		case packetFreelist <- p:
			// Parked for the next acquirePacket.
		default:
			// Freelist full; drop the *Packet for the GC.
		}
	}
}
// packetFreelist retains *Packet structs (and their backing Data arrays) so steady-state allocation drops to zero.
var packetFreelist = make(chan *Packet, 64)
// acquirePacket returns a recycled *Packet from packetFreelist when one is available and a freshly
// allocated one otherwise. It never blocks. Callers overwrite To, From and Data themselves.
func acquirePacket() *Packet {
	select {
	case recycled := <-packetFreelist:
		return recycled
	default:
	}
	return &Packet{}
}
type TesterConn struct { type TesterConn struct {
Addr netip.AddrPort Addr netip.AddrPort
@@ -64,13 +95,15 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn,
// this is an encrypted packet or a handshake message in most cases // this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send // packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *TesterConn) Send(packet *Packet) { func (u *TesterConn) Send(packet *Packet) {
h := &header.H{} if u.l.Enabled(context.Background(), slog.LevelDebug) {
// Parse the header only under debug logging, otherwise the
// allocation would show up in every Send call.
var h header.H
if err := h.Parse(packet.Data); err != nil { if err := h.Parse(packet.Data); err != nil {
panic(err) panic(err)
} }
if u.l.Enabled(context.Background(), slog.LevelDebug) {
u.l.Debug("UDP receiving injected packet", u.l.Debug("UDP receiving injected packet",
"header", h, "header", &h,
"udpAddr", packet.From, "udpAddr", packet.From,
"dataLen", len(packet.Data), "dataLen", len(packet.Data),
) )
@@ -107,15 +140,18 @@ func (u *TesterConn) Get(block bool) *Packet {
//********************************************************************************************************************// //********************************************************************************************************************//
func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
p := &Packet{ p := acquirePacket()
Data: make([]byte, len(b), len(b)), if cap(p.Data) < len(b) {
From: u.Addr, p.Data = make([]byte, len(b))
To: addr, } else {
p.Data = p.Data[:len(b)]
} }
copy(p.Data, b) copy(p.Data, b)
p.From = u.Addr
p.To = addr
select { select {
case <-u.done: case <-u.done:
p.Release()
return io.ErrClosedPipe return io.ErrClosedPipe
case u.TxPackets <- p: case u.TxPackets <- p:
return nil return nil
@@ -129,6 +165,7 @@ func (u *TesterConn) ListenOut(r EncReader) error {
return os.ErrClosed return os.ErrClosed
case p := <-u.RxPackets: case p := <-u.RxPackets:
r(p.From, p.Data) r(p.From, p.Data)
p.Release()
} }
} }
} }
+9 -4
View File
@@ -19,13 +19,18 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int)
return nil, fmt.Errorf("multiple udp listeners not supported on windows") return nil, fmt.Errorf("multiple udp listeners not supported on windows")
} }
var conn Conn
rc, err := NewRIOListener(l, ip, port) rc, err := NewRIOListener(l, ip, port)
if err == nil { if err == nil {
return rc, nil conn = rc
} } else {
l.Error("Falling back to standard udp sockets", "error", err) l.Error("Falling back to standard udp sockets", "error", err)
return NewGenericListener(l, ip, port, multi, batch) conn, err = NewGenericListener(l, ip, port, multi, batch)
if err != nil {
return nil, err
}
}
return wrapWithWDFBypass(l, conn), nil
} }
func NewListenConfig(multi bool) net.ListenConfig { func NewListenConfig(multi bool) net.ListenConfig {
+377
View File
@@ -0,0 +1,377 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
// Package wfp installs Windows Filtering Platform (WFP) PERMIT filters in a dynamic, session-scoped sublayer.
// Because WFP sits below Windows Defender Firewall, a high-weight permit at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4/V6 lets
// the matching inbound traffic through regardless of WDF rules.
//
// Each Session owns its own engine handle. When the handle closes, every dynamic object added during the session
// is auto-deleted by Windows, so there are no orphaned filters.
//
// Type definitions and constants are derived from the wireguard-windows firewall package (MIT).
// Only the subset we exercise is reproduced.
package wfp
import (
"fmt"
"unsafe"
"golang.org/x/sys/windows"
)
// FWPM layer GUIDs (fwpmu.h).
//
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = e1cd9fe7-f4b5-4273-96c0-592e487b8650
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = a3b42c97-9f04-4672-b87e-cee9c483257f
var (
fwpmLayerAleAuthRecvAcceptV4 = windows.GUID{
Data1: 0xe1cd9fe7, Data2: 0xf4b5, Data3: 0x4273,
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
}
fwpmLayerAleAuthRecvAcceptV6 = windows.GUID{
Data1: 0xa3b42c97, Data2: 0x9f04, Data3: 0x4672,
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
}
)
// FWPM_CONDITION_IP_LOCAL_INTERFACE = 4cd62a49-59c3-4969-b7f3-bda5d32890a4
var fwpmConditionIPLocalInterface = windows.GUID{
Data1: 0x4cd62a49, Data2: 0x59c3, Data3: 0x4969,
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
}
// FWPM_CONDITION_IP_PROTOCOL = 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
var fwpmConditionIPProtocol = windows.GUID{
Data1: 0x3971ef2b, Data2: 0x623e, Data3: 0x4f9a,
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
}
// FWPM_CONDITION_IP_LOCAL_PORT = 0c1ba1af-5765-453f-af22-a8f791ac775b
var fwpmConditionIPLocalPort = windows.GUID{
Data1: 0x0c1ba1af, Data2: 0x5765, Data3: 0x453f,
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
}
// IPPROTO_UDP from in.h.
const ipprotoUDP uint8 = 17
// FWP_ACTION_TYPE values (fwptypes.h). PERMIT is terminating.
const fwpActionPermit uint32 = 0x00001002 // 0x2 | FWP_ACTION_FLAG_TERMINATING(0x1000)
// FWP_DATA_TYPE values we use.
const (
fwpEmpty uint32 = 0
fwpUint8 uint32 = 1
fwpUint16 uint32 = 2
fwpUint64 uint32 = 4
)
// FWP_MATCH_TYPE values.
const fwpMatchEqual uint32 = 0
// FWPM_SESSION flags.
const fwpmSessionFlagDynamic uint32 = 0x1
// FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT prevents lower-priority filters in other sublayers (notably
// Windows Defender Firewall's MPSSVC_WF sublayer, which shares our 0xFFFF weight) from overriding this PERMIT.
// Without it, a default WDF block at the same sublayer weight can still win arbitration.
const fwpmFilterFlagClearActionRight uint32 = 0x8
// RPC authentication service.
// RPC_C_AUTHN_WINNT works on workgroup machines with no domain context, whereas
// RPC_C_AUTHN_DEFAULT falls back through a chain that can land on something WFP doesn't accept on a fresh box.
const rpcCAuthnWinNT uint32 = 10
// fwpByteBlob mirrors FWP_BYTE_BLOB. 16 bytes on 64-bit.
// The blank field is explicit padding so data sits at offset 8; do not reorder or remove fields.
type fwpByteBlob struct {
	size uint32
	_    uint32 // padding
	data *uint8
}
// fwpValue0 / FWP_CONDITION_VALUE0 layout. 16 bytes on 64-bit.
// The union is pointer-sized; types <= 32 bits (UINT8/16/32, INT8/16/32, float) live inline in the low bytes
// of `value`, while UINT64/INT64/double and aggregate types are stored *by pointer*, even on 64-bit, where the
// union member is declared as UINT64*. So when populating an FWP_UINT64 condition, pass
// uintptr(unsafe.Pointer(&luidVar)) instead of the LUID inline.
type fwpValue0 struct {
	// type_ holds one of the FWP_DATA_TYPE constants (fwpEmpty, fwpUint8, ...) and selects how value is read.
	type_ uint32
	_     uint32 // padding before union to 8-byte alignment
	// value is the union storage: an inline scalar or a pointer, depending on type_.
	value uintptr
}
// fwpmDisplayData0 / FWPM_DISPLAY_DATA0. 16 bytes on 64-bit.
// Both fields are pointers to UTF-16 strings, such as those produced by windows.UTF16PtrFromString.
type fwpmDisplayData0 struct {
	name        *uint16
	description *uint16
}
// fwpmAction0 / FWPM_ACTION0. 20 bytes; no leading padding because actionType
// is uint32 and GUID's first field is uint32.
type fwpmAction0 struct {
	// actionType takes an FWP_ACTION_TYPE value such as fwpActionPermit.
	actionType uint32
	filterType windows.GUID
}
// fwpmFilterCondition0 mirrors FWPM_FILTER_CONDITION0. 40 bytes on 64-bit.
// The trailing numbers are the byte sizes of each field.
type fwpmFilterCondition0 struct {
	fieldKey       windows.GUID // 16; which field is matched (one of the fwpmCondition* GUIDs)
	matchType      uint32       // 4; FWP_MATCH_TYPE, e.g. fwpMatchEqual
	_              uint32       // 4 padding
	conditionValue fwpValue0    // 16; value the field is compared against
}
// fwpmFilter0 mirrors FWPM_FILTER0. 200 bytes on 64-bit.
// Field order and the blank padding fields reproduce the 64-bit C layout, so fields must not be
// reordered or removed.
type fwpmFilter0 struct {
	filterKey           windows.GUID
	displayData         fwpmDisplayData0
	flags               uint32
	_                   uint32 // padding before *GUID
	providerKey         *windows.GUID
	providerData        fwpByteBlob
	layerKey            windows.GUID
	subLayerKey         windows.GUID
	weight              fwpValue0
	numFilterConditions uint32
	_                   uint32 // padding before pointer
	filterCondition     *fwpmFilterCondition0
	action              fwpmAction0
	_                   [4]byte // layout correction: action is 20 bytes, so this pads to the next 8-byte boundary
	providerContextKey  windows.GUID
	reserved            *windows.GUID
	filterID            uint64
	effectiveWeight     fwpValue0
}
// fwpmSublayer0 mirrors FWPM_SUBLAYER0. 72 bytes on 64-bit.
// Blank fields are explicit padding that keeps the Go layout at the stated size; do not reorder or remove.
type fwpmSublayer0 struct {
	subLayerKey  windows.GUID
	displayData  fwpmDisplayData0
	flags        uint32
	_            uint32 // padding before *GUID
	providerKey  *windows.GUID
	providerData fwpByteBlob
	weight       uint16
	_            [6]byte // padding to 72 bytes
}
// fwpmSession0 mirrors FWPM_SESSION0. 72 bytes on 64-bit.
// Blank fields are explicit padding that keeps the Go layout at the stated size; do not reorder or remove.
type fwpmSession0 struct {
	sessionKey  windows.GUID
	displayData fwpmDisplayData0
	// flags takes FWPM_SESSION flags; openDynamicEngine sets fwpmSessionFlagDynamic here.
	flags                uint32
	txnWaitTimeoutInMSec uint32
	processId            uint32
	_                    uint32 // padding before *SID
	sid                  *windows.SID
	username             *uint16
	kernelMode           uint8
	_                    [7]byte // tail padding
}
// fwpuclnt.dll bindings. Only the calls we use.
var (
modFwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
procFwpmEngineOpen0 = modFwpuclnt.NewProc("FwpmEngineOpen0")
procFwpmEngineClose0 = modFwpuclnt.NewProc("FwpmEngineClose0")
procFwpmSubLayerAdd0 = modFwpuclnt.NewProc("FwpmSubLayerAdd0")
procFwpmFilterAdd0 = modFwpuclnt.NewProc("FwpmFilterAdd0")
)
// Session wraps the WFP engine handle behind a single bypass operation. The handle belongs to a
// dynamic session: once it is closed, Windows automatically deletes every WFP object added during the
// session (sublayer + filters). That gives us correct cleanup even if the host process is killed hard
// between Permit* and Close.
type Session struct {
	// engine is the raw handle from FwpmEngineOpen0; zero once the session has been closed.
	engine uintptr
}

// Close releases the engine handle, at which point Windows deletes every dynamic object
// (sublayer + filters) the session installed. It is safe on a nil receiver and does nothing when the
// handle has already been released.
func (s *Session) Close() {
	if s != nil && s.engine != 0 {
		procFwpmEngineClose0.Call(s.engine)
		s.engine = 0
	}
}
// PermitInterface opens a new Session and installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4
// and _V6 scoped to the given network interface LUID, so inbound traffic on that interface bypasses
// Windows Defender Firewall. If either filter cannot be added, the session is closed again and the
// wrapped error is returned.
func PermitInterface(luid uint64) (*Session, error) {
	sess, sublayer, err := newSession()
	if err != nil {
		return nil, err
	}

	// fail tears the half-configured session down and hands back the already wrapped error.
	fail := func(wrapped error) (*Session, error) {
		sess.Close()
		return nil, wrapped
	}

	if v4Err := addInterfaceFilter(sess.engine, sublayer, fwpmLayerAleAuthRecvAcceptV4, luid); v4Err != nil {
		return fail(fmt.Errorf("add v4 filter: %w", v4Err))
	}
	if v6Err := addInterfaceFilter(sess.engine, sublayer, fwpmLayerAleAuthRecvAcceptV6, luid); v6Err != nil {
		return fail(fmt.Errorf("add v6 filter: %w", v6Err))
	}
	return sess, nil
}
// PermitUDPPort opens a new Session and installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4
// and _V6 scoped to UDP traffic with the given local port, so inbound UDP to that port on any interface
// bypasses Windows Defender Firewall. If either filter cannot be added, the session is closed again
// and the wrapped error is returned.
func PermitUDPPort(port uint16) (*Session, error) {
	sess, sublayer, err := newSession()
	if err != nil {
		return nil, err
	}

	// fail tears the half-configured session down and hands back the already wrapped error.
	fail := func(wrapped error) (*Session, error) {
		sess.Close()
		return nil, wrapped
	}

	if v4Err := addUDPPortFilter(sess.engine, sublayer, fwpmLayerAleAuthRecvAcceptV4, port); v4Err != nil {
		return fail(fmt.Errorf("add v4 filter: %w", v4Err))
	}
	if v6Err := addUDPPortFilter(sess.engine, sublayer, fwpmLayerAleAuthRecvAcceptV6, port); v6Err != nil {
		return fail(fmt.Errorf("add v6 filter: %w", v6Err))
	}
	return sess, nil
}
// newSession opens a dynamic WFP engine handle and registers a fresh sublayer on it. It returns the Session wrapping
// the handle together with the sublayer's GUID, which callers pass to the add*Filter helpers. If the sublayer cannot
// be registered the engine handle is closed again before the error is returned.
func newSession() (*Session, windows.GUID, error) {
	var noKey windows.GUID

	handle, err := openDynamicEngine()
	if err != nil {
		return nil, noKey, err
	}

	key, err := registerSublayer(handle)
	if err != nil {
		// No Session exists yet, so release the raw handle directly.
		procFwpmEngineClose0.Call(handle)
		return nil, noKey, err
	}

	return &Session{engine: handle}, key, nil
}
// openDynamicEngine calls FwpmEngineOpen0 against the local machine with the dynamic-session flag set and returns the
// resulting engine handle. A non-zero status from the call is reported as an error containing the raw status in hex.
func openDynamicEngine() (uintptr, error) {
	var handle uintptr
	opts := fwpmSession0{flags: fwpmSessionFlagDynamic}

	status, _, _ := procFwpmEngineOpen0.Call(
		0, // serverName == NULL (local)
		uintptr(rpcCAuthnWinNT),
		0, // authIdentity == NULL
		uintptr(unsafe.Pointer(&opts)),
		uintptr(unsafe.Pointer(&handle)), // out: receives the engine handle
	)
	if status == 0 {
		return handle, nil
	}
	return 0, fmt.Errorf("FwpmEngineOpen0: 0x%x", status)
}
// registerSublayer adds a sublayer to the session behind engine and returns the GUID it was registered under. The
// GUID is generated fresh on every call and the weight is 0xFFFF so the sublayer's filters arbitrate above WDF's
// default sublayer. No PERSISTENT flag is set, so the sublayer is dynamic and disappears when the engine handle is
// closed.
func registerSublayer(engine uintptr) (windows.GUID, error) {
	var none windows.GUID

	guid, err := windows.GenerateGUID()
	if err != nil {
		return none, fmt.Errorf("GenerateGUID for sublayer: %w", err)
	}

	// UTF16PtrFromString only fails on an embedded NUL, which these literals do not contain.
	displayName, _ := windows.UTF16PtrFromString("Nebula WDF bypass sublayer")
	displayDesc, _ := windows.UTF16PtrFromString("Permit filters bypassing Windows Defender Firewall")

	sublayer := fwpmSublayer0{
		subLayerKey: guid,
		displayData: fwpmDisplayData0{name: displayName, description: displayDesc},
		weight:      0xFFFF,
	}

	status, _, _ := procFwpmSubLayerAdd0.Call(
		engine,
		uintptr(unsafe.Pointer(&sublayer)),
		0, // sd == NULL
	)
	if status == 0 {
		return guid, nil
	}
	return none, fmt.Errorf("FwpmSubLayerAdd0: 0x%x", status)
}
// addInterfaceFilter adds one PERMIT filter to the given layer and sublayer that matches traffic whose local
// interface LUID equals luid. engine is the open WFP engine handle. A non-zero status from FwpmFilterAdd0 is
// returned as an error containing the raw status in hex.
func addInterfaceFilter(engine uintptr, sublayerKey, layer windows.GUID, luid uint64) error {
	// UTF16PtrFromString only fails on an embedded NUL, which these literals do not contain.
	name, _ := windows.UTF16PtrFromString("Nebula allow interface inbound")
	desc, _ := windows.UTF16PtrFromString("Permits inbound traffic on a nebula interface")

	// FWP_UINT64 is stored by pointer in the FWP_VALUE0 union, but on the Go side that slot is a plain uintptr: the
	// garbage collector neither tracks it nor rewrites it. Taking the address of the luid parameter would record a
	// stack address that becomes stale if the goroutine stack is relocated before the syscall executes. The value is
	// therefore kept in the same allocation as the condition. filter.filterCondition is a real pointer into that
	// allocation and filter is passed to Call, so the allocation escapes to the heap and stays reachable for the
	// whole call.
	match := &struct {
		cond fwpmFilterCondition0
		luid uint64
	}{luid: luid}
	match.cond = fwpmFilterCondition0{
		fieldKey:  fwpmConditionIPLocalInterface,
		matchType: fwpMatchEqual,
		conditionValue: fwpValue0{
			type_: fwpUint64,
			value: uintptr(unsafe.Pointer(&match.luid)),
		},
	}

	filter := fwpmFilter0{
		// filterKey left zero: WFP assigns one when the filter is added.
		displayData:         fwpmDisplayData0{name: name, description: desc},
		flags:               fwpmFilterFlagClearActionRight,
		layerKey:            layer,
		subLayerKey:         sublayerKey,
		weight:              fwpValue0{type_: fwpUint8, value: uintptr(15)},
		numFilterConditions: 1,
		filterCondition:     &match.cond,
		action:              fwpmAction0{actionType: fwpActionPermit},
	}

	r1, _, _ := procFwpmFilterAdd0.Call(
		engine,
		uintptr(unsafe.Pointer(&filter)),
		0, // sd == NULL
		0, // id == NULL
	)
	if r1 != 0 {
		return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1)
	}
	return nil
}
// addUDPPortFilter adds one PERMIT filter to the given layer and sublayer that matches
// (IP_PROTOCOL == UDP) AND (IP_LOCAL_PORT == port). engine is the open WFP engine handle. A non-zero status from
// FwpmFilterAdd0 is returned as an error containing the raw status in hex.
//
// FWP_UINT8 and FWP_UINT16 are <= 32 bits, so both condition values are stored inline in the FWP_VALUE0 union
// rather than by pointer.
func addUDPPortFilter(engine uintptr, sublayerKey, layer windows.GUID, port uint16) error {
	// UTF16PtrFromString only fails on an embedded NUL, which these literals do not contain.
	displayName, _ := windows.UTF16PtrFromString("Nebula allow UDP port inbound")
	displayDesc, _ := windows.UTF16PtrFromString("Permits inbound UDP to a nebula listener port")

	protocolIsUDP := fwpmFilterCondition0{
		fieldKey:       fwpmConditionIPProtocol,
		matchType:      fwpMatchEqual,
		conditionValue: fwpValue0{type_: fwpUint8, value: uintptr(ipprotoUDP)},
	}
	localPortMatches := fwpmFilterCondition0{
		fieldKey:       fwpmConditionIPLocalPort,
		matchType:      fwpMatchEqual,
		conditionValue: fwpValue0{type_: fwpUint16, value: uintptr(port)},
	}
	// WFP reads the conditions as a contiguous array starting at filterCondition, so pack them into one array.
	conditions := [2]fwpmFilterCondition0{protocolIsUDP, localPortMatches}

	spec := fwpmFilter0{
		// filterKey left zero, as in the interface filter.
		displayData:         fwpmDisplayData0{name: displayName, description: displayDesc},
		flags:               fwpmFilterFlagClearActionRight,
		layerKey:            layer,
		subLayerKey:         sublayerKey,
		weight:              fwpValue0{type_: fwpUint8, value: uintptr(15)},
		numFilterConditions: 2,
		filterCondition:     &conditions[0],
		action:              fwpmAction0{actionType: fwpActionPermit},
	}

	status, _, _ := procFwpmFilterAdd0.Call(
		engine,
		uintptr(unsafe.Pointer(&spec)),
		0, // sd == NULL
		0, // id == NULL
	)
	if status == 0 {
		return nil
	}
	return fmt.Errorf("FwpmFilterAdd0: 0x%x", status)
}