Merge remote-tracking branch 'origin/master' into fips140

This commit is contained in:
Wade Simmons
2026-06-01 09:52:57 -04:00
95 changed files with 5607 additions and 1384 deletions
+113
View File
@@ -0,0 +1,113 @@
name: Code-sign Windows binaries
description: >
Sign every .exe under a given path in place via the DefinedNet code-signer
Lambda. If `role` or `bucket` is empty, logs a notice and skips signing so
forks and dev branches without AWS access still produce usable builds.
inputs:
path:
description: "Directory whose .exe files should be signed in place"
required: true
role:
description: "IAM role ARN to assume via OIDC; empty disables signing"
required: false
default: ""
bucket:
description: "S3 staging bucket the code-signer Lambda reads from; empty disables signing"
required: false
default: ""
region:
description: "AWS region for the role and Lambda"
required: false
default: "us-east-2"
function-name:
description: "Code-signer Lambda function name"
required: false
default: "code-signer"
key-prefix:
description: "S3 key prefix the caller is authorized to write under"
required: false
default: "code-signing/slackhq/nebula"
runs:
using: composite
steps:
- name: Skip notice
if: inputs.role == '' || inputs.bucket == ''
shell: sh
run: echo "::notice::code-signer role or bucket not set; skipping code signing."
- name: Configure AWS credentials
if: inputs.role != '' && inputs.bucket != ''
uses: aws-actions/configure-aws-credentials@v6
with:
role-to-assume: ${{ inputs.role }}
aws-region: ${{ inputs.region }}
# Default is 12 retries to ride out IAM trust-policy propagation; once
# the role is stable we want a real misconfiguration to fail fast.
retry-max-attempts: 5
- name: Sign .exe files
if: inputs.role != '' && inputs.bucket != ''
shell: sh
env:
SIGN_PATH: ${{ inputs.path }}
BUCKET: ${{ inputs.bucket }}
FUNCTION_NAME: ${{ inputs.function-name }}
KEY_PREFIX: ${{ inputs.key-prefix }}
run: |
set -eu
RUN="${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
find "$SIGN_PATH" -name '*.exe' -print | while read -r path
do
rel=${path#"$SIGN_PATH"/}
file=$(basename "$path")
name=${file%.exe}
prefix="${KEY_PREFIX}/${RUN}"
src="${prefix}/unsigned/${rel}"
dst="${prefix}/signed/${rel}"
echo "::group::Sign ${rel}"
echo "Uploading unsigned to s3://${BUCKET}/${src}"
aws s3 cp --no-progress "$path" "s3://${BUCKET}/${src}" >/dev/null
echo "Invoking ${FUNCTION_NAME} Lambda"
payload=$(jq -nc \
--arg s "$src" \
--arg d "$dst" \
--arg p "$name" \
'{source_key: $s, dest_key: $d, program_name: $p}')
meta=$(aws lambda invoke \
--function-name "$FUNCTION_NAME" \
--cli-binary-format raw-in-base64-out \
--payload "$payload" \
--output json \
/tmp/sign-resp.json)
if echo "$meta" | jq -e '.FunctionError != null' >/dev/null
then
echo "::endgroup::"
echo "::error::code-signer Lambda failed for ${rel}"
cat /tmp/sign-resp.json >&2
exit 1
fi
echo "Downloading signed back to ${path}"
aws s3 cp --no-progress "s3://${BUCKET}/${dst}" "$path" >/dev/null
aws s3 rm "s3://${BUCKET}/${src}" >/dev/null 2>&1 || true
aws s3 rm "s3://${BUCKET}/${dst}" >/dev/null 2>&1 || true
# Sanity-check the bytes we got back actually carry an Authenticode
# signature that this machine can validate end to end.
status=$(powershell -NoProfile -Command "(Get-AuthenticodeSignature -FilePath '$path').Status" | tr -d '\r')
if [ "$status" != "Valid" ]
then
echo "::endgroup::"
echo "::error::${rel} signature status: ${status} (expected Valid)"
exit 1
fi
echo "Signed ${rel} (sha256=$(jq -r '.sha256' /tmp/sign-resp.json), status=${status})"
echo "::endgroup::"
done
-34
View File
@@ -1,34 +0,0 @@
name: gofmt
on:
push:
branches:
- master
pull_request:
paths:
- '.github/workflows/gofmt.yml'
- '**.go'
jobs:
gofmt:
name: Run gofmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
- name: gofmt
run: |
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
then
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
exit 1
fi
+18 -8
View File
@@ -24,7 +24,7 @@ jobs:
mv build/*.tar.gz release
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: linux-latest
path: release
@@ -32,6 +32,9 @@ jobs:
build-windows:
name: Build Windows
runs-on: windows-latest
permissions:
id-token: write
contents: read
steps:
- uses: actions/checkout@v6
@@ -54,8 +57,15 @@ jobs:
mkdir build\dist\windows
mv dist\windows\wintun build\dist\windows\
- name: Code-sign
uses: ./.github/actions/code-sign
with:
path: build
role: ${{ secrets.DEFINED_CODE_SIGNER_ROLE }}
bucket: ${{ secrets.DEFINED_CODE_SIGNER_BUCKET }}
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: windows-latest
path: build
@@ -75,7 +85,7 @@ jobs:
- name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v6
uses: Apple-Actions/import-codesign-certs@v7
with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
@@ -104,7 +114,7 @@ jobs:
fi
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: darwin-latest
path: ./release/*
@@ -128,21 +138,21 @@ jobs:
- name: Download artifacts
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
name: linux-latest
path: artifacts
- name: Login to Docker Hub
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
username: ${{ vars.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v4
- name: Build and push images
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
@@ -163,7 +173,7 @@ jobs:
- uses: actions/checkout@v6
- name: Download artifacts
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
path: artifacts
+81 -16
View File
@@ -14,10 +14,18 @@ on:
- 'go.sum'
jobs:
smoke-extra:
smoke-extra-libvirt:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run extra smoke tests
name: ${{ matrix.target }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
target:
- freebsd-amd64
- openbsd-amd64
- netbsd-amd64
- linux-amd64-ipv6disable
env:
VAGRANT_DEFAULT_PROVIDER: libvirt
steps:
@@ -40,28 +48,85 @@ jobs:
sudo chmod 666 /var/run/libvirt/libvirt-sock
vagrant plugin install vagrant-libvirt
- name: freebsd-amd64
run: make smoke-vagrant/freebsd-amd64
- name: ${{ matrix.target }}
run: make smoke-vagrant/${{ matrix.target }}
- name: openbsd-amd64
run: make smoke-vagrant/openbsd-amd64
timeout-minutes: 30
- name: netbsd-amd64
run: make smoke-vagrant/netbsd-amd64
# linux-386 needs VirtualBox, which conflicts with KVM/libvirt -- isolated job.
smoke-extra-virtualbox:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: linux-386
runs-on: ubuntu-latest
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
steps:
- name: linux-amd64-ipv6disable
run: make smoke-vagrant/linux-amd64-ipv6disable
- uses: actions/checkout@v6
# linux-386 runs last because it requires disabling KVM to use VirtualBox,
# which prevents libvirt (used by the other tests) from working after this point.
- name: install virtualbox for i386 test
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: add hashicorp source
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
- name: install vagrant and virtualbox
run: |
sudo apt-get install -y virtualbox
sudo apt-get update && sudo apt-get install -y vagrant virtualbox
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
- name: linux-386
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
run: make smoke-vagrant/linux-386
timeout-minutes: 30
smoke-windows:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run windows smoke test
runs-on: windows-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
# WSL2 + Ubuntu so the smoke can run a real linux peer with its own
# netns. iputils-ping is needed for the in-WSL ping check. WSL1 has no
# real kernel and would lack /dev/net/tun, so we have to force WSL2.
- uses: Vampire/setup-wsl@v3
with:
distribution: Ubuntu-24.04
additional-packages: iputils-ping iproute2
# Vampire/setup-wsl provisions WSL1 even when the WSL2 platform is present.
# Convert the distro to WSL2 explicitly before we try to use /dev/net/tun.
- name: convert distro to WSL2
shell: pwsh
run: |
wsl --set-version Ubuntu-24.04 2
wsl --shutdown
wsl --list --verbose
- name: build windows nebula
run: make bin-windows
- name: build linux nebula for WSL
shell: bash
env:
GOOS: linux
GOARCH: amd64
run: |
mkdir -p build/linux-amd64
go build -o build/linux-amd64/nebula ./cmd/nebula
- name: run smoke-windows
shell: pwsh
working-directory: ./.github/workflows/smoke
run: ./smoke-windows.ps1
timeout-minutes: 15
+272
View File
@@ -0,0 +1,272 @@
#!/usr/bin/env pwsh
# Windows smoke test for the nebula tun + UDP + NLM code paths.
#
# Topology:
# - lighthouse runs natively on the Windows host (wintun + windows UDP)
# - peer runs inside WSL2 (Linux build of nebula, /dev/net/tun)
#
# WSL2 gives us a real netns boundary so the loopback fast-path on Windows
# does not short-circuit the overlay -- when WSL pings the lighthouse VPN IP,
# Linux has no idea that IP is local to the Windows host, so the packet is
# forced through nebula. Same in reverse.
$ErrorActionPreference = 'Stop'
# wsl.exe emits UTF-16 LE by default which PowerShell reads as bytes, mangling
# every captured string. WSL_UTF8 makes wsl.exe emit UTF-8 instead.
$env:WSL_UTF8 = '1'
$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.."
$Nebula = Join-Path $RepoRoot 'nebula.exe'
$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe'
$NebulaLinux = Join-Path $RepoRoot 'build\linux-amd64\nebula'
if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" }
if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" }
if (-not (Test-Path $NebulaLinux)) { throw "missing $NebulaLinux; build the linux nebula first" }
# Matches the distro installed by Vampire/setup-wsl in smoke-extra.yml.
$Distro = 'Ubuntu-24.04'
$listed = (wsl --list --quiet 2>$null) -join "`n"
if ($listed -notmatch [regex]::Escape($Distro)) {
throw "WSL distro $Distro not registered. Got: $listed"
}
Write-Host "Using WSL distro: $Distro"
# Windows host as seen from inside WSL: WSL's default-route gateway. We extract
# it with a regex rather than awk fields so PowerShell does not eat any '$N'
# tokens, and tabs/double-spaces in `ip route` output do not confuse a cut.
$ipCmd = 'ip route show default | grep -oE "([0-9]+\.){3}[0-9]+" | head -1'
$WindowsIp = (wsl -d $Distro -- bash -c $ipCmd).Trim()
if (-not $WindowsIp) { throw "could not determine Windows host IP from WSL" }
Write-Host "Windows host IP from WSL: $WindowsIp"
$WorkDir = Join-Path $env:TEMP 'nebula-smoke-windows'
if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir }
New-Item -ItemType Directory -Path $WorkDir | Out-Null
$WslDir = '/tmp/nebula-smoke'
wsl -d $Distro -- bash -c "rm -rf $WslDir && mkdir -p $WslDir" | Out-Null
$DevName = 'nebula-smoke'
$Ip1 = '192.168.241.1'
$Ip2 = '192.168.241.2'
$Port = 4242
& $NebulaCert ca -name 'smoke-ca' -out-crt "$WorkDir\ca.crt" -out-key "$WorkDir\ca.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" }
& $NebulaCert sign -name 'lighthouse' -networks "$Ip1/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\lighthouse.crt" -out-key "$WorkDir\lighthouse.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign lighthouse failed (exit $LASTEXITCODE)" }
& $NebulaCert sign -name 'peer' -networks "$Ip2/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\peer.crt" -out-key "$WorkDir\peer.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign peer failed (exit $LASTEXITCODE)" }
# Windows lighthouse config.
@"
pki:
ca: $WorkDir\ca.crt
cert: $WorkDir\lighthouse.crt
key: $WorkDir\lighthouse.key
static_host_map: {}
lighthouse:
am_lighthouse: true
interval: 60
hosts: []
listen:
host: 0.0.0.0
port: $Port
tun:
disabled: false
dev: $DevName
drop_local_broadcast: false
drop_multicast: false
tx_queue: 500
mtu: 1300
network_category: private
logging:
level: info
format: text
firewall:
outbound_action: drop
inbound_action: drop
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
"@ | Out-File -FilePath "$WorkDir\lighthouse.yml" -Encoding utf8
# WSL peer config (paths are POSIX, deliberately).
@"
pki:
ca: $WslDir/ca.crt
cert: $WslDir/peer.crt
key: $WslDir/peer.key
static_host_map:
"${Ip1}": ["${WindowsIp}:$Port"]
lighthouse:
am_lighthouse: false
interval: 60
hosts:
- "${Ip1}"
listen:
host: 0.0.0.0
port: 0
tun:
disabled: false
dev: nebula1
drop_local_broadcast: false
drop_multicast: false
tx_queue: 500
mtu: 1300
logging:
level: info
format: text
firewall:
outbound_action: drop
inbound_action: drop
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
"@ | Out-File -FilePath "$WorkDir\peer.yml" -Encoding utf8
# Stage WSL artifacts. Convert Windows paths to WSL paths ourselves rather than
# calling `wslpath`, because PowerShell's argument-passing to external EXEs
# strips backslashes from path arguments in ways that are hard to escape around.
# Translate an absolute, drive-letter Windows path (e.g. C:\foo\bar) into the
# matching WSL mount path (/mnt/c/foo/bar). The drive letter is lower-cased and
# every backslash in the remainder becomes a forward slash. Throws when the
# input does not start with "<letter>:\".
function ConvertTo-WslPath {
    param([string]$WindowsPath)
    if ($WindowsPath -match '^([A-Za-z]):\\(.*)$') {
        $drive = $matches[1].ToLower()
        $rest = $matches[2].Replace('\','/')
        return "/mnt/$drive/$rest"
    }
    throw "cannot convert path to WSL: $WindowsPath"
}
$WslWorkDir = ConvertTo-WslPath $WorkDir
$WslNebulaPath = ConvertTo-WslPath $NebulaLinux
wsl -d $Distro -- bash -c "cp '$WslWorkDir/ca.crt' '$WslWorkDir/peer.crt' '$WslWorkDir/peer.key' '$WslWorkDir/peer.yml' $WslDir/ && cp '$WslNebulaPath' $WslDir/nebula && chmod +x $WslDir/nebula"
# Make sure WSL has tun support and /dev/net/tun is usable before starting
# nebula. Diagnostics come first so a failure here points at the real problem
# (e.g. WSL1 distros do not have a real kernel and will not have tun).
Write-Host '=== WSL diagnostic ==='
wsl --version 2>&1 | Out-Host
wsl --list --verbose 2>&1 | Out-Host
wsl -d $Distro -u root -- uname -a | Out-Host
wsl -d $Distro -u root -- bash -c "modprobe tun 2>&1 || true; mkdir -p /dev/net; [ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; chmod 600 /dev/net/tun; ls -l /dev/net/tun"
if ($LASTEXITCODE -ne 0) { throw "failed to prepare /dev/net/tun in WSL (TUN support missing?)" }
# Deliberately no New-NetFirewallRule calls here -- nebula's windows_bypass_wdf
# feature is supposed to install WFP permit filters that let inbound traffic
# through Windows Defender Firewall on its own. If this smoke regresses, that
# feature regressed.
$lhOut = Join-Path $WorkDir 'lighthouse.out.log'
$lhErr = Join-Path $WorkDir 'lighthouse.err.log'
$lhProc = Start-Process -FilePath $Nebula -ArgumentList @('-config', "$WorkDir\lighthouse.yml") `
-PassThru -NoNewWindow `
-RedirectStandardOutput $lhOut `
-RedirectStandardError $lhErr
# Run nebula in WSL as root with no sudo + no shell wrapper. PowerShell's
# Start-Process arg quoting mangles `bash -c "..."` strings that contain
# spaces/redirections, so we skip bash entirely and let Start-Process do the
# stdout/stderr capture itself.
$peerOut = Join-Path $WorkDir 'peer.out.log'
$peerErr = Join-Path $WorkDir 'peer.err.log'
$peerProc = Start-Process -FilePath 'wsl' `
-ArgumentList @('-d', $Distro, '-u', 'root', '--', "$WslDir/nebula", '-config', "$WslDir/peer.yml") `
-PassThru -NoNewWindow `
-RedirectStandardOutput $peerOut `
-RedirectStandardError $peerErr
# Poll $Predicate every 500 ms until it yields a truthy value or $TimeoutSec
# seconds have elapsed. Returns silently on success; throws
# "timed out waiting for: <What>" once the deadline has passed. Anything the
# predicate itself throws propagates to the caller immediately.
function Wait-Until {
    param([scriptblock]$Predicate, [int]$TimeoutSec, [string]$What)
    $deadline = (Get-Date).AddSeconds($TimeoutSec)
    while ($true) {
        # Deadline is checked before each probe, so a zero timeout never
        # invokes the predicate.
        if ((Get-Date) -ge $deadline) {
            throw "timed out waiting for: $What"
        }
        if (& $Predicate) {
            return
        }
        Start-Sleep -Milliseconds 500
    }
}
try {
Wait-Until -TimeoutSec 30 -What "windows wintun adapter $DevName with NetworkCategory=Private" -Predicate {
if ($lhProc.HasExited) { throw "lighthouse exited (code $($lhProc.ExitCode)) before tun was ready" }
$p = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue
$p -and ("$($p.NetworkCategory)" -ieq 'Private')
}
Write-Host "OK: $DevName NetworkCategory=Private"
Wait-Until -TimeoutSec 30 -What "WSL nebula1 with $Ip2" -Predicate {
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before tun was ready" }
$r = wsl -d $Distro -u root -- bash -c "ip -o addr show nebula1 2>/dev/null | grep -q 'inet $Ip2' && echo yes"
("$r").Trim() -eq 'yes'
}
Write-Host "OK: WSL nebula1 has $Ip2"
Wait-Until -TimeoutSec 30 -What "ping from WSL peer to windows lighthouse ($Ip1)" -Predicate {
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before ping succeeded" }
$r = wsl -d $Distro -u root -- bash -c "ping -c1 -W1 $Ip1 >/dev/null 2>&1 && echo OK"
("$r").Trim() -eq 'OK'
}
Write-Host "OK: WSL peer -> windows lighthouse"
Wait-Until -TimeoutSec 30 -What "ping from windows lighthouse to WSL peer ($Ip2)" -Predicate {
$null = & ping.exe -n 1 -w 1000 $Ip2
$LASTEXITCODE -eq 0
}
Write-Host "OK: windows lighthouse -> WSL peer"
Write-Host ''
Write-Host 'All smoke checks passed.'
}
catch {
Write-Host ''
Write-Host '=== lighthouse stdout ==='
Get-Content $lhOut -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== lighthouse stderr ==='
Get-Content $lhErr -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== peer stdout ==='
Get-Content $peerOut -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== peer stderr ==='
Get-Content $peerErr -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== nebula WFP filters ==='
# Dump nebula-installed filters so we can verify they got registered with
# the conditions we expect.
$wfpDump = Join-Path $WorkDir 'wfp.xml'
netsh wfp show filters file=$wfpDump 2>&1 | Out-Null
if (Test-Path $wfpDump) {
Select-String -Path $wfpDump -Pattern 'Nebula' -Context 0,80 -ErrorAction SilentlyContinue | Out-Host
}
throw
}
finally {
if (-not $lhProc.HasExited) {
Stop-Process -Id $lhProc.Id -Force -ErrorAction SilentlyContinue
$lhProc.WaitForExit(5000) | Out-Null
}
wsl -d $Distro -u root -- bash -c "pkill -f $WslDir/nebula 2>/dev/null; true" | Out-Null
# pkill returns 1 when no match and wsl propagates that; the smoke is done
# so we don't want it to leak into the script's exit code.
$global:LASTEXITCODE = 0
if ($peerProc -and -not $peerProc.HasExited) {
Stop-Process -Id $peerProc.Id -Force -ErrorAction SilentlyContinue
}
}
@@ -1,7 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure("2") do |config|
config.vm.box = "generic/netbsd9"
config.vm.box = "DefinedNet/netbsd10"
config.vm.synced_folder "../build", "/nebula", type: "rsync"
end
+100 -98
View File
@@ -13,8 +13,8 @@ on:
- 'go.sum'
jobs:
test-linux:
name: Build all and test on ubuntu-linux
static:
name: Static checks
runs-on: ubuntu-latest
steps:
@@ -25,8 +25,16 @@ jobs:
go-version: '1.25'
check-latest: true
- name: Build
run: make all
- name: Install goimports
run: go install golang.org/x/tools/cmd/goimports@latest
- name: gofmt
run: |
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
then
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
exit 1
fi
- name: Vet
run: make vet
@@ -36,87 +44,43 @@ jobs:
with:
version: v2.5
- name: Test
run: make test
- name: End 2 end
run: make e2evv
- name: Build test mobile
run: make build-test-mobile
- uses: actions/upload-artifact@v6
with:
name: e2e packet flow linux-latest
path: e2e/mermaid/linux-latest
if-no-files-found: warn
test-linux-boringcrypto:
name: Build and test on linux with boringcrypto
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make bin-boringcrypto
- name: Test
run: make test-boringcrypto
- name: End 2 end
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
test-linux-fips140:
name: Build and test on linux with fips140=on
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make fips140
- name: Test
run: make fips140 test
- name: End 2 end
run: make fips140 e2evv
test-linux-pkcs11:
name: Build and test on linux with pkcs11
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make bin-pkcs11
- name: Test
run: make test-pkcs11
test:
name: Build and test on ${{ matrix.os }}
name: Test ${{ matrix.name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [windows-latest, macos-latest]
include:
- name: linux
os: ubuntu-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
- name: linux-boringcrypto
os: ubuntu-latest
build-cmd: make bin-boringcrypto
test-cmd: make test-boringcrypto
e2e-cmd: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
- name: linux-fips140
os: ubuntu-latest
build-cmd: make fips140
test-cmd: make fips140 test
e2e-cmd: make fips140 e2evv
- name: linux-pkcs11
os: ubuntu-latest
build-cmd: make bin-pkcs11
test-cmd: make test-pkcs11
e2e-cmd: ''
- name: macos
os: macos-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
- name: windows
os: windows-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
steps:
- uses: actions/checkout@v6
@@ -126,28 +90,66 @@ jobs:
go-version: '1.25'
check-latest: true
- name: Build nebula
run: go build ./cmd/nebula
- name: Build
run: ${{ matrix.build-cmd }}
- name: Build nebula-cert
run: go build ./cmd/nebula-cert
- name: Vet
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.5
- name: Cross-build darwin-amd64
if: matrix.name == 'macos'
run: GOARCH=amd64 go build -o /tmp/nebula-amd64 ./cmd/nebula && GOARCH=amd64 go build -o /tmp/nebula-cert-amd64 ./cmd/nebula-cert
- name: Test
run: make test
run: ${{ matrix.test-cmd }}
- name: End 2 end
run: make e2evv
if: matrix.e2e-cmd != ''
run: ${{ matrix.e2e-cmd }}
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
if: matrix.e2e-cmd != '' && always()
with:
name: e2e packet flow ${{ matrix.os }}
path: e2e/mermaid/${{ matrix.os }}
name: e2e packet flow ${{ matrix.name }}
path: e2e/mermaid/
if-no-files-found: warn
cross-build:
name: Cross-build ${{ matrix.name }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- {name: linux-arm, make-target: all-cross-linux-arm}
- {name: linux-mips, make-target: all-cross-linux-mips}
- {name: linux-other, make-target: all-cross-linux-other}
- {name: freebsd, make-target: all-freebsd}
- {name: openbsd, make-target: all-openbsd}
- {name: netbsd, make-target: all-netbsd}
- {name: windows, make-target: all-cross-windows}
- {name: mobile, make-target: build-test-mobile}
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build ${{ matrix.name }}
run: make -j"$(nproc)" ${{ matrix.make-target }}
finish:
name: CI status
if: always()
needs: [static, test, cross-build]
runs-on: ubuntu-latest
steps:
- name: Fail if any upstream job failed
if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled')
run: |
echo "upstream results: ${{ toJSON(needs) }}"
exit 1
- name: All upstream jobs passed
run: echo "ok"
+42 -1
View File
@@ -60,6 +60,18 @@ ALL = $(ALL_LINUX) \
windows-amd64 \
windows-arm64
# Cross-build shards used by .github/workflows/test.yml — same as ALL_*
# but with the arch that has a native CI runner removed, so the cross-build
# job is not duplicating coverage the native test jobs already give.
ALL_CROSS_LINUX = $(filter-out linux-amd64,$(ALL_LINUX))
# ALL_CROSS_LINUX further split into family sub-shards so each can run on
# its own CI runner in parallel. Union of the three must equal
# ALL_CROSS_LINUX; adding a new linux arch goes into the matching family.
ALL_CROSS_LINUX_ARM = linux-arm-5 linux-arm-6 linux-arm-7 linux-arm64
ALL_CROSS_LINUX_MIPS = linux-mips linux-mipsle linux-mips64 linux-mips64le linux-mips-softfloat
ALL_CROSS_LINUX_OTHER = linux-386 linux-ppc64le linux-riscv64 linux-loong64
e2e:
$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
@@ -82,6 +94,35 @@ DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
all-linux: $(ALL_LINUX:%=build/%/nebula) $(ALL_LINUX:%=build/%/nebula-cert)
all-freebsd: $(ALL_FREEBSD:%=build/%/nebula) $(ALL_FREEBSD:%=build/%/nebula-cert)
all-openbsd: $(ALL_OPENBSD:%=build/%/nebula) $(ALL_OPENBSD:%=build/%/nebula-cert)
all-netbsd: $(ALL_NETBSD:%=build/%/nebula) $(ALL_NETBSD:%=build/%/nebula-cert)
all-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert build/darwin-arm64/nebula build/darwin-arm64/nebula-cert
all-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
# CI cross-build shards. darwin-arm64 is covered by the native macos-latest
# job; windows-amd64 is covered by the native windows-latest job; both are
# omitted here to avoid building them a second time. darwin-amd64 stays in
# all-cross-darwin because intel mac is only a labeled/master-time native
# job, so PRs still need cross-build coverage for it.
all-cross-linux: $(ALL_CROSS_LINUX:%=build/%/nebula) $(ALL_CROSS_LINUX:%=build/%/nebula-cert)
all-cross-linux-arm: $(ALL_CROSS_LINUX_ARM:%=build/%/nebula) $(ALL_CROSS_LINUX_ARM:%=build/%/nebula-cert)
all-cross-linux-mips: $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula) $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula-cert)
all-cross-linux-other: $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula) $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula-cert)
all-cross-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
all-cross-windows: build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
docker: docker/linux-$(shell go env GOARCH)
release: $(ALL:%=build/nebula-%.tar.gz)
@@ -267,5 +308,5 @@ smoke-vagrant/%: bin-docker build/%/nebula
cd .github/workflows/smoke/ && ./smoke-vagrant.sh $*
.FORCE:
.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv fips140 proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
.PHONY: all all-linux all-freebsd all-openbsd all-netbsd all-darwin all-windows all-cross-linux all-cross-linux-arm all-cross-linux-mips all-cross-linux-other all-cross-darwin all-cross-windows bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv fips140 proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
.DEFAULT_GOAL := bin
+174 -33
View File
@@ -2,24 +2,42 @@ package nebula
import (
"context"
"fmt"
"log/slog"
"math"
mathbits "math/bits"
"github.com/rcrowley/go-metrics"
)
const bitsPerWord = 64
// Bits is a sliding-window anti-replay tracker. The window is stored as a
// circular bitmap packed into uint64 words (8x denser than a []bool), so a
// length-N window costs N/8 bytes. length must be a power of two.
type Bits struct {
length uint64
lengthMask uint64
current uint64
bits []bool
bits []uint64
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(bits uint64) *Bits {
func NewBits(length uint64) *Bits {
if length == 0 || length&(length-1) != 0 {
panic(fmt.Sprintf("Bits length must be a power of two, got %d", length))
}
nWords := length / bitsPerWord
if nWords == 0 {
nWords = 1
}
b := &Bits{
length: bits,
bits: make([]bool, bits, bits),
length: length,
lengthMask: length - 1,
bits: make([]uint64, nWords),
current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
@@ -27,71 +45,194 @@ func NewBits(bits uint64) *Bits {
}
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
b.bits[0] = 1
return b
}
// get reports whether the bit for counter i is set in the circular bitmap.
func (b *Bits) get(i uint64) bool {
	pos := i & b.lengthMask
	// pos is a bit index, not a word index: pos>>6 picks the uint64 that
	// holds the bit and pos&63 picks the bit inside that word.
	word := b.bits[pos>>6]
	return (word>>(pos&63))&1 != 0
}
// set marks the bit for counter i in the circular bitmap.
func (b *Bits) set(i uint64) {
	pos := i & b.lengthMask
	// Split the bit index into the containing word and the offset within it.
	word, offset := pos>>6, pos&63
	b.bits[word] |= uint64(1) << offset
}
// clearRange zeroes `count` consecutive bits of the circular bitmap, starting
// at bit position `startPos` (which must already be masked into [0, length)),
// and returns how many of those bits were set beforehand. count must be in
// [1, length].
func (b *Bits) clearRange(startPos, count uint64) uint64 {
	// Clearing a full window touches every word: tally them all, then wipe.
	if count >= b.length {
		var total uint64
		for idx := range b.bits {
			total += uint64(mathbits.OnesCount64(b.bits[idx]))
		}
		clear(b.bits)
		return total
	}

	var cleared uint64
	cursor, left := startPos, count

	// Leading chunk: stop at whichever comes first of the next word boundary,
	// the end of the requested range, or the end of the bitmap (a window
	// shorter than one word wraps before it reaches a word boundary).
	offset := cursor & 63
	n := min(uint64(64)-offset, left, b.length-cursor)
	mask := uint64(math.MaxUint64)
	if n < 64 {
		mask = ((uint64(1) << n) - 1) << offset
	}
	idx := cursor >> 6
	cleared += uint64(mathbits.OnesCount64(b.bits[idx] & mask))
	b.bits[idx] &^= mask
	left -= n
	cursor = (cursor + n) & b.lengthMask

	// Middle: zero whole words, counting the bits that were set in each.
	for ; left >= 64; left -= 64 {
		idx = cursor >> 6
		cleared += uint64(mathbits.OnesCount64(b.bits[idx]))
		b.bits[idx] = 0
		cursor = (cursor + 64) & b.lengthMask
	}

	// Trailing chunk: the low `left` bits of the word the cursor landed on.
	if left > 0 {
		idx = cursor >> 6
		mask = (uint64(1) << left) - 1
		cleared += uint64(mathbits.OnesCount64(b.bits[idx] & mask))
		b.bits[idx] &^= mask
	}
	return cleared
}
// strictlyWithinWindow reports whether counter i falls inside the window that
// trails b.current. Two cases qualify:
//   - warmup: the window has not slid yet (b.current < b.length) and
//     i < b.length;
//   - steady state: i > b.current-b.length.
//
// During warmup the subtraction wraps around (unsigned), producing a value so
// large that the second test only succeeds for counters near the top of the
// uint64 range; the warmup test is what admits the small counters that are
// actually in range.
func (b *Bits) strictlyWithinWindow(i uint64) bool {
	warmingUp := b.current < b.length
	return (warmingUp && i < b.length) || i > b.current-b.length
}
// Check returns true if i is within (or way out in front of) the window, and not a replay
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current {
return true
}
// If i is within the window, check if it's been set already.
if i > b.current-b.length || i < b.length && b.current < b.length {
return !b.bits[i%b.length]
if b.strictlyWithinWindow(i) {
return !b.get(i)
}
// Not within the window
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("rejected a packet (top)",
"current", b.current,
"incoming", i,
)
l.Debug("rejected a packet (top)", "current", b.current, "incoming", i)
}
return false
}
// Update has three branches:
// - i == b.current+1: fast path; advance the cursor by one and lose-count
// the slot we just stomped (only past warmup; see the i > b.length guard
// below).
// - i > b.current+1: jump path; clear all slots between current and i
// (or up to a full window's worth, whichever is smaller) via clearRange,
// then mark i. Two arms here: a warmup arm that handles the very first
// window before the cursor has slid, and a steady-state arm that treats
// every cleared empty slot as a lost packet.
// - i <= b.current: in-window check for duplicates; out-of-window otherwise.
//
// NewBits seeds bits[0]=1 so counter 0 looks "received". Update only clears
// that marker on a jump that carries the cursor to b.length or beyond
// (clearRange wraps around onto position 0 in that case); by then
// b.current >= b.length and the marker is no longer consulted. The marker
// prevents a fictitious "lost" hit on the first real counter.
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
// If i is the next number, return true and update current.
// Fast path: i is the next expected counter. Split out so the function
// stays small and avoids paying for the slow paths' slog argument-build
// stack frame on every call. The bit read/test/write is inlined to
// touch the backing word once.
if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
// The very first window can only be tracked as lost once we are on the 2nd window or greater
if b.bits[i%b.length] == false && i > b.length {
pos := i & b.lengthMask
word := pos >> 6
mask := uint64(1) << (pos & 63)
w := b.bits[word]
if i > b.length && w&mask == 0 {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
b.bits[word] = w | mask
b.current = i
return true
}
return b.updateSlow(l, i)
}
// updateSlow handles jumps, in-window backfill, dupes, and out-of-window.
func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool {
// If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current {
lost := int64(0)
// Zero out the bits between the current and the new counter value, limited by the window size,
// since the window is shifting
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
end := i
if end > b.current+b.length {
end = b.current + b.length
}
count := end - b.current
startPos := (b.current + 1) & b.lengthMask
var lost int64
if b.current >= b.length {
// Steady state: every cleared slot is past warmup, so any unset
// bit we evict is a lost packet from the previous cycle.
wasSet := b.clearRange(startPos, count)
lost = int64(count) - int64(wasSet)
} else {
// Warmup (the very first window). Some cleared slots represent
// packets <= length where eviction is not "lost" in the usual
// sense. This branch is taken at most once per connection so we
// don't bother optimizing it.
for n := b.current + 1; n <= end; n++ {
if !b.get(n) && n > b.length {
lost++
}
b.bits[n%b.length] = false
}
b.clearRange(startPos, count)
}
// Only record any skipped packets as a result of the window moving further than the window length
// Any loss within the new window will be accounted for in future calls
lost += max(0, int64(i-b.current-b.length))
// Anything past the new window can never be backfilled, so it's lost.
if i > b.current+b.length {
lost += int64(i - b.current - b.length)
}
b.lostCounter.Inc(lost)
b.bits[i%b.length] = true
b.set(i)
b.current = i
return true
}
// If i is within the current window but below the current counter,
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
// If i is within the current window but below the current counter, check to see if it's a duplicate
if b.strictlyWithinWindow(i) {
pos := i & b.lengthMask
word := pos >> 6
mask := uint64(1) << (pos & 63)
w := b.bits[word]
if b.current == i || w&mask != 0 {
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
@@ -104,7 +245,7 @@ func (b *Bits) Update(l *slog.Logger, i uint64) bool {
return false
}
b.bits[i%b.length] = true
b.bits[word] = w | mask
return true
}
+276 -129
View File
@@ -7,61 +7,79 @@ import (
"github.com/stretchr/testify/assert"
)
// snapshot expands the packed bitmap into one bool per slot (b.length
// entries, entry k holding b.get(k)) so tests can compare the window
// against a plain []bool literal.
func (b *Bits) snapshot() []bool {
	flags := make([]bool, b.length)
	for slot := range flags {
		flags[slot] = b.get(uint64(slot))
	}
	return flags
}
func TestBitsRequiresPowerOfTwo(t *testing.T) {
assert.Panics(t, func() { NewBits(10) })
assert.Panics(t, func() { NewBits(0) })
assert.NotPanics(t, func() { NewBits(1) })
assert.NotPanics(t, func() { NewBits(16) })
assert.NotPanics(t, func() { NewBits(1024) })
assert.NotPanics(t, func() { NewBits(16384) })
}
func TestBits(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
// make sure it is the right size
assert.Len(t, b.bits, 10)
b := NewBits(16)
assert.EqualValues(t, 16, b.length)
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1))
assert.True(t, b.Update(l, 1))
assert.EqualValues(t, 1, b.current)
g := []bool{true, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
g := []bool{true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Receive two
assert.True(t, b.Check(l, 2))
assert.True(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
g = []bool{true, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
g = []bool{true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Receive two again - it will fail
assert.False(t, b.Check(l, 2))
assert.False(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15))
assert.True(t, b.Update(l, 15))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Jump ahead to 25, which clears the window and sets slot 25%16 = 9.
assert.True(t, b.Check(l, 25))
assert.True(t, b.Update(l, 25))
assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14))
assert.True(t, b.Update(l, 14))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 24, which is in window (current 25, length 16, window covers [10,25]).
assert.True(t, b.Check(l, 24))
assert.True(t, b.Update(l, 24))
assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Mark 5, which is not allowed because it is not in the window
// Mark 5, not allowed because 5 <= current-length (25-16=9).
assert.False(t, b.Check(l, 5))
assert.False(t, b.Update(l, 5))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// make sure we handle wrapping around once to the current position
b = NewBits(10)
// Make sure we handle wrapping around once to the same slot. With
// length=16, packets 1 and 17 share slot 1.
b = NewBits(16)
assert.True(t, b.Update(l, 1))
assert.True(t, b.Update(l, 11))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
assert.True(t, b.Update(l, 17))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}, b.snapshot())
// Walk through a few windows in order
b = NewBits(10)
b = NewBits(16)
for i := uint64(1); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
@@ -72,24 +90,31 @@ func TestBits(t *testing.T) {
// TestBitsLargeJumps pins lostCounter across three jumps that each overshoot
// a 16-slot window.
func TestBitsLargeJumps(t *testing.T) {
	l := test.NewLogger()

	// length=16. Update(55) from current=0:
	// warmup, per-bit loop sees no n>16 with unset bits (slot 0 was set by
	// NewBits and gets re-evaluated when n=16; n=16 is not strictly > 16),
	// so the loop contributes 0. The jump exceeds the window so we record
	// 55 - 0 - 16 = 39 packets fell out the back.
	b := NewBits(16)
	b.lostCounter.Clear()
	assert.True(t, b.Update(l, 55))
	assert.Equal(t, int64(39), b.lostCounter.Count())

	// Update(100): clears 16 slots starting at slot 56%16=8. Only slot 7 (for
	// packet 55) was set, so 16 - 1 = 15 evicted slots had unset bits.
	// Plus 100 - 55 - 16 = 29 packets fell past the window. Total 44.
	assert.True(t, b.Update(l, 100))
	assert.Equal(t, int64(39+44), b.lostCounter.Count())

	// Update(200): same shape: 16 - 1 = 15 evicted unset, plus
	// 200 - 100 - 16 = 84 past window. Total 99.
	assert.True(t, b.Update(l, 200))
	assert.Equal(t, int64(39+44+99), b.lostCounter.Count())
}
func TestBitsDupeCounter(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b := NewBits(16)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
@@ -114,120 +139,117 @@ func TestBitsDupeCounter(t *testing.T) {
// TestBitsOutOfWindowCounter checks that a counter behind the window bumps
// outOfWindowCounter exactly once and is not charged as lost or duplicate.
func TestBitsOutOfWindowCounter(t *testing.T) {
	l := test.NewLogger()
	b := NewBits(16)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
	b.outOfWindowCounter.Clear()

	// Jump to 20 (warmup branch + 4 past-window packets).
	assert.True(t, b.Update(l, 20))
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())

	// 9 single-step advances, each evicts a slot whose bit was cleared during
	// the jump above and whose value was never seen, so each contributes 1
	// to lostCounter.
	for n := uint64(21); n <= 29; n++ {
		assert.True(t, b.Update(l, n))
	}
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())

	// 0 is below current-length (29-16=13) so it falls outside the window.
	assert.False(t, b.Update(l, 0))
	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())

	// 4 from the Update(20) jump + 9 from 21..29.
	assert.Equal(t, int64(13), b.lostCounter.Count())
	assert.Equal(t, int64(0), b.dupeCounter.Count())
	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
// TestBitsLostCounter walks a 16-slot window through unit advances and jumps
// and pins the running lostCounter total after each phase. dupeCounter and
// outOfWindowCounter must stay at zero throughout.
func TestBitsLostCounter(t *testing.T) {
	l := test.NewLogger()
	b := NewBits(16)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
	b.outOfWindowCounter.Clear()

	// Walk 20..29. Update(20) is a warmup jump that overshoots the window by
	// 20 - 0 - 16 = 4; each of the 9 unit advances 21..29 then evicts a slot
	// the jump cleared and nothing re-set, adding 1 apiece.
	for n := uint64(20); n <= 29; n++ {
		assert.True(t, b.Update(l, n))
	}
	assert.Equal(t, int64(13), b.lostCounter.Count())
	assert.Equal(t, int64(0), b.dupeCounter.Count())
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())

	b = NewBits(16)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
	b.outOfWindowCounter.Clear()

	// Update(15) clears the warmup window (no lost), sets slot 15.
	assert.True(t, b.Update(l, 15))
	assert.Equal(t, int64(0), b.lostCounter.Count())

	// Update(16): slot 0 was already set (NewBits seeded it), and 16 is not
	// strictly > length, so nothing is recorded as lost.
	assert.True(t, b.Update(l, 16))
	assert.Equal(t, int64(0), b.lostCounter.Count())

	// Update(17): we jumped straight from 0 to 15, so slot 1 was cleared
	// (and never re-set). 17 > 16 is past warmup, so packet 1 is recorded lost.
	assert.True(t, b.Update(l, 17))
	assert.Equal(t, int64(1), b.lostCounter.Count())

	// Fill in 18..30 in single steps. Each i evicts slot i%16. Slots 2..14
	// were all cleared during Update(15), and we never re-set any of them,
	// so each i in 18..30 is a fresh lost packet — 13 more.
	for n := uint64(18); n <= 30; n++ {
		assert.True(t, b.Update(l, n))
	}
	assert.Equal(t, int64(14), b.lostCounter.Count())

	// Jump ahead by exactly one window size.
	assert.True(t, b.Update(l, 46))
	// end = min(46, 30+16) = 46, count = 16, all slots cleared. Before the
	// jump every slot 0..15 had been set (Update(15), (16), (17), 18..30),
	// so wasSet=16 and 46 == current+length means no past-window slack:
	// lost contribution = 0.
	assert.Equal(t, int64(14), b.lostCounter.Count())

	// Walk 47..55. The Update(46) jump cleared every slot, so only slot 14
	// (for packet 46) is set when we start. Each subsequent unit step lands
	// on a slot that was cleared and is past warmup, so it counts as lost.
	// 9 more = 23.
	for n := uint64(47); n <= 55; n++ {
		assert.True(t, b.Update(l, n))
	}
	assert.Equal(t, int64(23), b.lostCounter.Count())

	// Jump ahead by two windows: clears the window plus past-window loss.
	assert.True(t, b.Update(l, 87))
	// current=55, length=16. end = min(87, 71) = 71. count=16, all slots
	// cleared. Slots set before the clear are slots 14,15,0..7 (10 total).
	// Lost from clear = 16 - 10 = 6. Past window: 87 - 55 - 16 = 16. +22.
	assert.Equal(t, int64(45), b.lostCounter.Count())
	assert.Equal(t, int64(0), b.dupeCounter.Count())
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b := NewBits(16)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
// Receive 4, backfill 1, then 9, 2, 3, 5, 6, 7 (skip 8), 10, 11, 14.
// Then jump to 25 — slot 25%16=9 is being evicted, but it had been set
// (we received packet 9), so no spurious lost increment. The original
// regression was about double-counting a missing packet when its slot
// got cleared on a jump. With the jump path now using clearRange's
// word-level wasSet count, the same semantics hold.
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
@@ -244,7 +266,7 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
// Skip packet 8.
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
@@ -252,9 +274,23 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
// Jump to 25. With length=16, slot 25%16=9 corresponds to packet 9
// (which we DID receive), so its bit is set and no lost++ from that
// eviction. The trace below shows the only loss is packet 8.
assert.True(t, b.Update(l, 25))
// current was 14, i=25. end=min(25,30)=25. count=11. startPos=15.
// steady? current=14<16, so warmup branch: per-bit n=15..25, count those
// with !get(n) AND n>16. n=17..25 are >16. Among slots 17%16=1..25%16=9
// did we set slots 1..9 (packets 1..9)? Yes for all but slot 8 (packet 8
// was skipped). n=24 maps to slot 8 which is FALSE → lost++. All other
// n in 17..25 map to slots that are set. n=16 is not strictly > 16. So
// lost = 1.
assert.Equal(t, int64(1), b.lostCounter.Count())
// Fill in 12, 13, 15, 16. Each is below current=25 (in-window). 16 must
// recheck slot 0 — it was set by NewBits and then cleared by the
// Update(25) jump, so 16 backfills cleanly.
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
@@ -263,29 +299,140 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
// We missed packet 8 above and that loss is still recorded once, never
// double-counted, never zeroed.
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {
for i := range z.bits {
z.bits[i] = true
}
for i := range z.bits {
z.bits[i] = false
// TestBitsWarmupOvershoot drives the jump path from a fresh window
// (current=0, only the slot-0 marker set) to a counter more than one full
// window ahead. It pins three things: the jump charges only the counters
// that overshoot the window, the cursor lands on the new counter, and the
// following unit advance is charged as a steady-state loss.
func TestBitsWarmupOvershoot(t *testing.T) {
	const windowSize = 16

	logger := test.NewLogger()
	win := NewBits(windowSize)
	win.lostCounter.Clear()

	// 0 -> 20 with a 16-slot window. The warmup loop looks at counters
	// 1..16 and only charges unset ones that are strictly above the window
	// length; none of 1..16 is, so it adds nothing. The overshoot past the
	// window is 20 - 0 - 16 = 4, and that is the whole charge.
	assert.True(t, win.Update(logger, 20))
	assert.Equal(t, int64(4), win.lostCounter.Count())
	assert.Equal(t, uint64(20), win.current)

	// current=20 is past the first window now. Advancing to 21 reuses slot
	// 21%16=5, which the jump cleared and nothing has set since: +1 lost.
	assert.True(t, win.Update(logger, 21))
	assert.Equal(t, int64(5), win.lostCounter.Count())
}
// TestBitsCheckAcrossWarmupBoundary pins Check's answers on both sides of
// the point where the cursor has travelled a full window. Before that point
// current-length would wrap around uint64, so every counter below the window
// length must be treated as in-window; afterwards the lower edge is
// current-length, exclusive.
func TestBitsCheckAcrossWarmupBoundary(t *testing.T) {
	const windowSize = 16

	logger := test.NewLogger()
	win := NewBits(windowSize)

	// Fresh window, current=0: counter 0 is the seeded marker and must be
	// reported as already received.
	assert.False(t, win.Check(logger, 0), "marker slot should look already-received")

	// Every other counter below the window size is in-window and unset.
	for i := uint64(1); i < windowSize; i++ {
		assert.True(t, win.Check(logger, i), "warmup in-window i=%d should be accepted", i)
	}

	// Counters at or beyond the window size are simply ahead of the cursor.
	assert.True(t, win.Check(logger, windowSize))
	assert.True(t, win.Check(logger, 1_000_000))

	// Move the cursor well past the first window.
	assert.True(t, win.Update(logger, 100))

	// current=100, size=16: accepted counters must be > 84, so the tracked
	// range is [85..100].
	assert.False(t, win.Check(logger, 84)) // exactly on the excluded edge
	assert.True(t, win.Check(logger, 85))  // first counter inside, unset
	assert.False(t, win.Check(logger, 100)) // the cursor itself, already set
	assert.False(t, win.Check(logger, 50))  // far behind the window
}
// TestBitsMarkerInvariant follows the seeded slot-0 marker through the first
// window of an 8-slot bitmap: counter 0 must read as a duplicate for as long
// as the cursor is below the window size, and the unit advance that finally
// reuses slot 0 (counter 8) must not be charged as a lost packet.
func TestBitsMarkerInvariant(t *testing.T) {
	const windowSize = 8

	logger := test.NewLogger()
	win := NewBits(windowSize)

	// The marker makes counter 0 look received from the start.
	assert.False(t, win.Check(logger, 0))

	// Feeding 0 while current=0 is counted as a duplicate.
	win.dupeCounter.Clear()
	assert.False(t, win.Update(logger, 0))
	assert.Equal(t, int64(1), win.dupeCounter.Count())

	// Unit advances 1..7 each mark their own slot and leave slot 0 alone.
	for n := uint64(1); n < windowSize; n++ {
		assert.True(t, win.Update(logger, n))
	}

	// The marker is therefore still set and 0 still reads as a replay.
	assert.False(t, win.Check(logger, 0))

	// Counter 8 lands on slot 0. The fast path only charges a loss when the
	// counter is strictly greater than the window size, and 8 == 8, so the
	// lost counter stays at zero.
	win.lostCounter.Clear()
	assert.True(t, win.Update(logger, windowSize))
	assert.Equal(t, int64(0), win.lostCounter.Count())

	// Slot 0 now stands for counter 8, which has just been received.
	assert.False(t, win.Check(logger, windowSize))
}
// BenchmarkBitsUpdateInOrder measures Update when every counter is exactly
// current+1, i.e. counters 1, 2, 3, ... b.N fed in order.
func BenchmarkBitsUpdateInOrder(b *testing.B) {
	logger := test.NewLogger()
	window := NewBits(16384)
	for counter := uint64(1); counter <= uint64(b.N); counter++ {
		window.Update(logger, counter)
	}
}
// BenchmarkBitsUpdateReorder measures Update under pairwise reordering: the
// counters arrive as 2, 1, 4, 3, 6, 5, ... so each iteration feeds one
// counter that skips ahead by two and then the one it skipped, which is
// already behind the cursor.
func BenchmarkBitsUpdateReorder(b *testing.B) {
	logger := test.NewLogger()
	window := NewBits(16384)
	for iter := 0; iter < b.N; iter++ {
		ahead := uint64(iter)*2 + 2
		window.Update(logger, ahead)
		window.Update(logger, ahead-1)
	}
}
// BenchmarkBitsUpdateLargeJumps measures Update when every counter is 1000
// ahead of the previous one (1000, 2000, 3000, ...), so each call goes
// through the jump path and its range clear.
func BenchmarkBitsUpdateLargeJumps(b *testing.B) {
	logger := test.NewLogger()
	window := NewBits(16384)
	for step := 1; step <= b.N; step++ {
		window.Update(logger, uint64(step)*1000)
	}
}
+4
View File
@@ -217,6 +217,10 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp
return nil, err
}
if signer.Certificate.Curve() != c.Curve() {
return nil, ErrCurveMismatch
}
if signer.Certificate.Expired(now) {
return nil, ErrRootExpired
}
+28
View File
@@ -654,3 +654,31 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
// TestCertificateV2_CurveMismatch checks that CAPool.verify rejects a leaf
// whose curve differs from its signing CA's curve, and that the rejection is
// ErrCurveMismatch rather than some later, unrelated failure.
func TestCertificateV2_CurveMismatch(t *testing.T) {
	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})

	caPem, err := ca.MarshalPEM()
	require.NoError(t, err)

	caPool := NewCAPool()
	b, err := caPool.AddCAFromPEM(caPem)
	require.NoError(t, err)
	assert.Empty(t, b)

	// Baseline: a P256 leaf inside the CA's 10.0.0.0/16 network verifies.
	cIp1 := mustParsePrefixUnmapped("10.0.0.1/24")
	c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1}, nil, []string{"test"})
	// Fingerprint errors are left unchecked here; the verify call below is
	// the assertion that matters.
	fp, _ := c.Fingerprint()
	_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
	require.NoError(t, err)

	// Relabel the same leaf as CURVE25519 while the CA stays P256. Assert
	// the specific sentinel: a bare require.Error would also pass if the
	// certificate were rejected for any other reason.
	c2 := c.(*certificateV2)
	c2.curve = Curve_CURVE25519
	fp, _ = c.Fingerprint()
	_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
	require.ErrorIs(t, err, ErrCurveMismatch)
}
+3
View File
@@ -112,6 +112,9 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
}
switch c.details.curve {
case Curve_CURVE25519:
if len(key) != ed25519.PublicKeySize {
return false //avoids a panic internal to ed25519
}
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+3
View File
@@ -151,6 +151,9 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
switch c.curve {
case Curve_CURVE25519:
if len(key) != ed25519.PublicKeySize {
return false //avoids a panic internal to ed25519
}
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+1
View File
@@ -22,6 +22,7 @@ var (
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present")
ErrCurveMismatch = errors.New("certificate curve does not match CA")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
+10 -4
View File
@@ -13,6 +13,12 @@ import (
"golang.org/x/crypto/ed25519"
)
// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)
// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
var err error
@@ -34,10 +40,10 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
t := &TBSCertificate{
@@ -70,11 +76,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
if len(networks) == 0 {
+10 -4
View File
@@ -14,6 +14,12 @@ import (
"golang.org/x/crypto/ed25519"
)
// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)
// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
var err error
@@ -35,10 +41,10 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
t := &cert.TBSCertificate{
@@ -71,11 +77,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
var pub, priv []byte
+30 -5
View File
@@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
} else {
// out-key is meaningless under PKCS#11 because the private key never
// leaves the HSM; reject it so we never silently accept or claim a
// stdout slot for it.
outKeySet := false
cf.set.Visit(func(f *flag.Flag) {
if f.Name == "out-key" {
outKeySet = true
}
})
if outKeySet {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
}
if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
return err
@@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
}
}
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-crt", *cf.outCertPath,
"out-qr", *cf.outQRPath,
); err != nil {
return err
}
var passphrase []byte
if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
@@ -261,15 +283,17 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
Curve: curve,
}
if !isP11 {
if !isP11 && !isStdio(*cf.outKeyPath) {
if _, err := os.Stat(*cf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
}
}
if !isStdio(*cf.outCertPath) {
if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
}
}
var c cert.Certificate
var b []byte
@@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
}
err = os.WriteFile(*cf.outKeyPath, b, 0600)
err = writeOutput(*cf.outKeyPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
@@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while marshalling certificate: %s", err)
}
err = os.WriteFile(*cf.outCertPath, b, 0600)
err = writeOutput(*cf.outCertPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
@@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*cf.outQRPath, b, 0600)
err = writeOutput(*cf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@@ -332,6 +356,7 @@ func caSummary() string {
// caHelp prints the help text for the ca subcommand to out: the usage
// banner, the shared stdio hint, then the defaults of every ca flag.
func caHelp(out io.Writer) {
	flags := newCaFlags()
	banner := "Usage of " + os.Args[0] + " " + caSummary() + "\n"
	// Help output is best effort, so write errors are deliberately dropped.
	_, _ = out.Write([]byte(banner))
	_, _ = out.Write([]byte(stdioHelpText))
	flags.set.SetOutput(out)
	flags.set.PrintDefaults()
}
+72 -7
View File
@@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -argon-iterations uint\n"+
" \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+
" -argon-memory uint\n"+
@@ -84,7 +85,7 @@ func Test_ca(t *testing.T) {
err: nil,
}
pwPromptOb := "Enter passphrase: "
pwPromptEB := "Enter passphrase: "
// required args
assertHelpError(t, ca(
@@ -168,8 +169,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, pwPromptEB, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
@@ -207,8 +208,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, pwPromptEB, eb.String())
// test when user fails to enter a password
os.Remove(keyF.Name())
@@ -217,8 +218,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
@@ -247,3 +248,67 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name())
}
// Test_ca_stdio covers "-" (stdout) handling in the ca subcommand: each output
// may be sent to stdout on its own, while two outputs aimed at stdout must be
// rejected up front, before anything is written or a passphrase is requested.
func Test_ca_stdio(t *testing.T) {
	nopw := &StubPasswordReader{}

	// CreateTemp is only used to reserve a unique name. Close the handle so the
	// descriptor is not leaked, then remove the file because ca() refuses to
	// overwrite an existing key or cert.
	keyF, err := os.CreateTemp("", "ca.key")
	require.NoError(t, err)
	require.NoError(t, keyF.Close())
	require.NoError(t, os.Remove(keyF.Name()))
	defer os.Remove(keyF.Name())

	crtF, err := os.CreateTemp("", "ca.crt")
	require.NoError(t, err)
	require.NoError(t, crtF.Close())
	require.NoError(t, os.Remove(crtF.Name()))
	defer os.Remove(crtF.Name())

	// out-crt on stdout, out-key on disk
	ob := &bytes.Buffer{}
	eb := &bytes.Buffer{}
	require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw))
	assert.Empty(t, eb.String())
	c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.True(t, c.IsCA())
	assert.Equal(t, "test-ca", c.Name())

	// out-key on stdout, out-crt on disk
	os.Remove(keyF.Name())
	ob.Reset()
	eb.Reset()
	require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw))
	assert.Empty(t, eb.String())
	_, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, cert.Curve_CURVE25519, curve)

	// dual stdout is rejected up front
	os.Remove(crtF.Name())
	ob.Reset()
	eb.Reset()
	require.EqualError(t,
		ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw),
		`-out-key and -out-crt both set to "-", only one output may write to stdout`)
	assert.Empty(t, ob.String())

	// an output conflict combined with -encrypt must error BEFORE prompting
	// for a passphrase; pr would record any read attempt
	tracker := &trackingPasswordReader{}
	ob.Reset()
	eb.Reset()
	require.EqualError(t,
		ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker),
		`-out-key and -out-crt both set to "-", only one output may write to stdout`)
	assert.Empty(t, ob.String())
	assert.Empty(t, eb.String())
	assert.Zero(t, tracker.calls, "passphrase prompt should not have been called")
}
// trackingPasswordReader is a password reader stub that counts how many
// times a passphrase was requested and always answers with an empty one.
type trackingPasswordReader struct {
	calls int // number of ReadPassword invocations so far
}

// ReadPassword records the call and returns an empty, non-nil passphrase
// with no error.
func (r *trackingPasswordReader) ReadPassword() ([]byte, error) {
	r.calls += 1
	empty := []byte{}
	return empty, nil
}
+13 -2
View File
@@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
} else if *cf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
return err
@@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
}
}
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-pub", *cf.outPubPath,
); err != nil {
return err
}
if isP11 {
p11Client, err := pkclient.FromUrl(*cf.p11url)
if err != nil {
@@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while getting public key: %w", err)
}
} else {
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
}
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-pub: %s", err)
}
@@ -102,6 +112,7 @@ func keygenSummary() string {
// keygenHelp prints the help text for the keygen subcommand to out: the
// usage banner, the shared stdio hint, then the defaults of every flag.
func keygenHelp(out io.Writer) {
	flags := newKeygenFlags()
	header := []string{
		"Usage of " + os.Args[0] + " " + keygenSummary() + "\n",
		stdioHelpText,
	}
	for _, text := range header {
		// Help output is best effort, so write errors are deliberately dropped.
		_, _ = out.Write([]byte(text))
	}
	flags.set.SetOutput(out)
	flags.set.PrintDefaults()
}
+41
View File
@@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -curve string\n"+
" \tECDH Curve (25519, P256) (default \"25519\")\n"+
" -out-key string\n"+
@@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) {
require.NoError(t, err)
assert.Len(t, lPub, 32)
}
// Test_keygen_stdio covers "-" (stdout) handling in the keygen subcommand:
// either the public or the private key may be written to stdout, but asking
// for both on stdout is rejected before anything is written.
func Test_keygen_stdio(t *testing.T) {
	// CreateTemp is only used to reserve a unique name. Close the handle so the
	// descriptor is not leaked before the placeholder file is removed.
	keyF, err := os.CreateTemp("", "test.key")
	require.NoError(t, err)
	require.NoError(t, keyF.Close())
	require.NoError(t, os.Remove(keyF.Name()))
	defer os.Remove(keyF.Name())

	pubF, err := os.CreateTemp("", "test.pub")
	require.NoError(t, err)
	require.NoError(t, pubF.Close())
	require.NoError(t, os.Remove(pubF.Name()))
	defer os.Remove(pubF.Name())

	// out-pub on stdout, out-key on disk
	ob := &bytes.Buffer{}
	eb := &bytes.Buffer{}
	require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb))
	assert.Empty(t, eb.String())
	lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, cert.Curve_CURVE25519, curve)
	assert.Len(t, lPub, 32)

	// out-key on stdout, out-pub on disk
	os.Remove(keyF.Name())
	ob.Reset()
	eb.Reset()
	require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb))
	assert.Empty(t, eb.String())
	lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, cert.Curve_CURVE25519, curve)
	assert.Len(t, lKey, 32)

	// both on stdout is a conflict caught up front
	ob.Reset()
	eb.Reset()
	require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb),
		`-out-key and -out-pub both set to "-", only one output may write to stdout`)
	assert.Empty(t, ob.String())
}
+3 -1
View File
@@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) {
}
password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
// Terminal echo is off while reading, so the user's Enter key does not
// produce a visible newline. Emit one on stderr to match the prompt.
fmt.Fprintln(os.Stderr)
return password, err
}
+18 -3
View File
@@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return err
}
rawCert, err := os.ReadFile(*pf.path)
var claims ioClaims
if err := reserveInputs(&claims, "path", *pf.path); err != nil {
return err
}
if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil {
return err
}
rawCert, err := readInput("path", *pf.path, &claims)
if err != nil {
return fmt.Errorf("unable to read cert; %s", err)
}
// When the QR is going to stdout, suppress the human-readable text/json
// output so the binary stream is not contaminated.
qrToStdout := isStdio(*pf.outQRPath)
var c cert.Certificate
var qrBytes []byte
part := 0
@@ -57,12 +69,14 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while unmarshaling cert: %s", err)
}
if !qrToStdout {
if *pf.json {
jsonCerts = append(jsonCerts, c)
} else {
_, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n"))
}
}
if *pf.outQRPath != "" {
b, err := c.MarshalPEM()
@@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++
}
if *pf.json {
if *pf.json && !qrToStdout {
b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b)
_, _ = out.Write([]byte("\n"))
@@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*pf.outQRPath, b, 0600)
err = writeOutput(*pf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@@ -107,6 +121,7 @@ func printSummary() string {
// printHelp prints the help text for the print subcommand to out: the usage
// banner, the shared stdio hint, then the defaults of every print flag.
func printHelp(out io.Writer) {
	flags := newPrintFlags()
	banner := []byte("Usage of " + os.Args[0] + " " + printSummary() + "\n")
	// Help output is best effort, so write errors are deliberately dropped.
	_, _ = out.Write(banner)
	_, _ = out.Write([]byte(stdioHelpText))
	flags.set.SetOutput(out)
	flags.set.PrintDefaults()
}
+39
View File
@@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -json\n"+
" \tOptional: outputs certificates in json format\n"+
" -out-qr string\n"+
@@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) {
ob.String(),
)
assert.Empty(t, eb.String())
// read cert from stdin
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-"}, ob, eb)
require.NoError(t, err)
assert.Equal(
t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
assert.Empty(t, eb.String())
// -out-qr - sends only the PNG to stdout, suppressing the cert dump
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
stdout := ob.Bytes()
require.NotEmpty(t, stdout)
// PNG magic, no PEM/JSON noise prepended
assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8])
assert.NotContains(t, string(stdout), "NebulaCertificate")
assert.NotContains(t, string(stdout), `"details"`)
// json + out-qr - still suppresses json
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4])
assert.NotContains(t, ob.String(), `"details"`)
}
// NewTestCaCert will generate a CA cert
+38 -16
View File
@@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key")
}
if isP11 && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
var v4Networks []netip.Prefix
var v6Networks []netip.Prefix
@@ -102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
var claims ioClaims
if err := reserveInputs(&claims,
"ca-key", *sf.caKeyPath,
"ca-crt", *sf.caCertPath,
"in-pub", *sf.inPubPath,
); err != nil {
return err
}
if err := reserveOutputs(&claims,
"out-key", *sf.outKeyPath,
"out-crt", *sf.outCertPath,
"out-qr", *sf.outQRPath,
); err != nil {
return err
}
var curve cert.Curve
var caKey []byte
if !isP11 {
var rawCAKey []byte
rawCAKey, err := os.ReadFile(*sf.caKeyPath)
rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca-key: %s", err)
}
@@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if len(passphrase) == 0 {
// ask for a passphrase until we get one
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if errors.Is(err, ErrNoTerminal) {
@@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
}
rawCACert, err := os.ReadFile(*sf.caCertPath)
rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca-crt: %s", err)
}
@@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if *sf.inPubPath != "" {
var pubCurve cert.Curve
rawPub, err := os.ReadFile(*sf.inPubPath)
rawPub, err := readInput("in-pub", *sf.inPubPath, &claims)
if err != nil {
return fmt.Errorf("error while reading in-pub: %s", err)
}
@@ -266,17 +291,11 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve)
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
if !isStdio(*sf.outCertPath) {
if _, err := os.Stat(*sf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
}
}
var crts []cert.Certificate
@@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
if !isP11 && *sf.inPubPath == "" {
if !isStdio(*sf.outKeyPath) {
if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
}
}
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
@@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
b = append(b, sb...)
}
err = os.WriteFile(*sf.outCertPath, b, 0600)
err = writeOutput(*sf.outCertPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
@@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*sf.outQRPath, b, 0600)
err = writeOutput(*sf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@@ -440,6 +461,7 @@ func signSummary() string {
// signHelp prints the help text for the sign subcommand to out: the usage
// banner, the shared stdio hint, then the defaults of every sign flag.
func signHelp(out io.Writer) {
	flags := newSignFlags()
	usage := "Usage of " + os.Args[0] + " " + signSummary() + "\n"
	for _, chunk := range [][]byte{[]byte(usage), []byte(stdioHelpText)} {
		// Help output is best effort, so write errors are deliberately dropped.
		_, _ = out.Write(chunk)
	}
	flags.set.SetOutput(out)
	flags.set.PrintDefaults()
}
+112 -8
View File
@@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca-crt string\n"+
" \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+
" -ca-key string\n"+
@@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) {
// test with the proper password
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
ob.Reset()
eb.Reset()
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
@@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) {
testpw.password = []byte("invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the wrong password in environment
ob.Reset()
@@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String())
// test an error condition
ob.Reset()
@@ -425,6 +429,106 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
}
// Test_signCert_stdio covers "-" handling in the sign subcommand: reading the
// CA key or the public key from stdin, writing the cert or the key to stdout,
// and rejecting two flags that claim the same stream before any I/O happens.
func Test_signCert_stdio(t *testing.T) {
	nopw := &StubPasswordReader{
		password: []byte(""),
		err:      nil,
	}

	caPub, caPriv, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)
	rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)
	ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
	rawCACrt, err := ca.MarshalPEM()
	require.NoError(t, err)

	// The CA fixtures are written once and then only read back by path, so the
	// handles are closed as soon as the contents are on disk.
	caCrtF, err := os.CreateTemp("", "sign-cert.crt")
	require.NoError(t, err)
	defer os.Remove(caCrtF.Name())
	_, err = caCrtF.Write(rawCACrt)
	require.NoError(t, err)
	require.NoError(t, caCrtF.Close())

	caKeyF, err := os.CreateTemp("", "sign-cert.key")
	require.NoError(t, err)
	defer os.Remove(caKeyF.Name())
	_, err = caKeyF.Write(rawCAKey)
	require.NoError(t, err)
	require.NoError(t, caKeyF.Close())

	// Output paths only need a unique name: close the handle so it is not
	// leaked, then remove the file because signCert refuses to overwrite.
	keyF, err := os.CreateTemp("", "sign.key")
	require.NoError(t, err)
	require.NoError(t, keyF.Close())
	require.NoError(t, os.Remove(keyF.Name()))
	defer os.Remove(keyF.Name())

	// ca-key on stdin, cert to stdout
	withStdin(t, bytes.NewReader(rawCAKey))
	ob := &bytes.Buffer{}
	eb := &bytes.Buffer{}
	args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"}
	require.NoError(t, signCert(args, ob, eb, nopw))
	assert.Empty(t, eb.String())
	lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, "stdin-test", lCrt.Name())
	assert.True(t, lCrt.CheckSignature(caPub))

	// two flags reading from stdin should error before any read attempt;
	// otherwise an interactive shell would hang on io.ReadAll
	stdinIn := bytes.NewReader(rawCAKey)
	withStdin(t, stdinIn)
	ob.Reset()
	eb.Reset()
	args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
	require.EqualError(t, signCert(args, ob, eb, nopw),
		`-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
	assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front")

	// two flags writing to stdout should error before any output is written
	// AND before stdin is consumed
	stdinR := bytes.NewReader(rawCAKey)
	withStdin(t, stdinR)
	ob.Reset()
	eb.Reset()
	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"}
	require.EqualError(t, signCert(args, ob, eb, nopw),
		`-out-key and -out-crt both set to "-", only one output may write to stdout`)
	assert.Empty(t, ob.String())
	// stdin should be untouched because the conflict was caught up front
	assert.Equal(t, len(rawCAKey), stdinR.Len())

	// out-key on stdout, cert on disk
	keyF2, err := os.CreateTemp("", "sign.key")
	require.NoError(t, err)
	require.NoError(t, keyF2.Close())
	require.NoError(t, os.Remove(keyF2.Name()))
	defer os.Remove(keyF2.Name())
	crtF, err := os.CreateTemp("", "sign.crt")
	require.NoError(t, err)
	require.NoError(t, crtF.Close())
	require.NoError(t, os.Remove(crtF.Name()))
	defer os.Remove(crtF.Name())
	ob.Reset()
	eb.Reset()
	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"}
	require.NoError(t, signCert(args, ob, eb, nopw))
	assert.Empty(t, eb.String())
	_, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, cert.Curve_CURVE25519, curve)

	// in-pub on stdin (caller already has a keypair, only the cert is generated)
	inPub, _ := x25519Keypair()
	rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)
	withStdin(t, bytes.NewReader(rawInPub))
	os.Remove(crtF.Name())
	ob.Reset()
	eb.Reset()
	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"}
	require.NoError(t, signCert(args, ob, eb, nopw))
	assert.Empty(t, eb.String())
	stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
	require.NoError(t, err)
	assert.Equal(t, "in-pub-test", stdinCrt.Name())
	assert.Equal(t, inPub, stdinCrt.PublicKey())
}
+117
View File
@@ -0,0 +1,117 @@
package main
import (
"fmt"
"io"
"os"
)
const (
	// stdioPath is the sentinel path: an input flag set to it reads from
	// stdin and an output flag set to it writes to stdout, instead of
	// touching a file on disk.
	stdioPath = "-"

	// stdioHelpText is printed right below the Usage line of every
	// subcommand's help, documenting the - convention in one place rather
	// than repeating it on each flag.
	stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"
)

// stdinReader supplies the bytes for any input flag set to "-". It is a
// package level variable so tests can substitute a deterministic reader;
// tests that replace it must not use t.Parallel().
var stdinReader io.Reader = os.Stdin
// ioClaims remembers, for one command invocation, the name of the flag that
// owns stdin and the name of the flag that owns stdout. An empty name means
// the stream is unclaimed, which lets a second flag asking for the same
// stream be refused.
type ioClaims struct {
	in, out string
}
// claimIn records flagName as the owner of stdin. Claiming again with the
// same name is a no-op; a claim by a different flag is rejected with an error
// naming both flags, and the existing owner is left in place.
func (c *ioClaims) claimIn(flagName string) error {
	switch owner := c.in; owner {
	case "", flagName:
		c.in = flagName
		return nil
	default:
		return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", owner, flagName, stdioPath)
	}
}
// claimOut records flagName as the owner of stdout. Claiming again with the
// same name is a no-op; a claim by a different flag is rejected with an error
// naming both flags, and the existing owner is left in place.
func (c *ioClaims) claimOut(flagName string) error {
	if owner := c.out; owner == "" || owner == flagName {
		c.out = flagName
		return nil
	}
	return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", c.out, flagName, stdioPath)
}
// reserveInputs takes alternating (flagName, path) arguments and claims stdin
// for every path that equals stdioPath. Call it before reading any input, so
// that two flags competing for stdin are reported right away rather than
// leaving the command blocked in io.ReadAll on input that never arrives.
func reserveInputs(claims *ioClaims, pairs ...string) error {
	claimStdin := (*ioClaims).claimIn
	return reserveStdio(claims, "reserveInputs", claimStdin, pairs)
}
// reserveOutputs takes alternating (flagName, path) arguments and claims
// stdout for every path that equals stdioPath. Call it before writing any
// output, so that a conflict is reported before one stream has been partly
// written.
func reserveOutputs(claims *ioClaims, pairs ...string) error {
	claimStdout := (*ioClaims).claimOut
	return reserveStdio(claims, "reserveOutputs", claimStdout, pairs)
}
// reserveStdio is the shared worker behind reserveInputs and reserveOutputs.
// pairs holds alternating flag names and paths; claim is invoked, in order,
// for each flag whose path is stdioPath, and the first error stops the walk.
// An odd number of entries is a programming error and panics with a message
// prefixed by who.
func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error {
	if len(pairs)%2 == 1 {
		panic(who + " requires alternating name, path pairs")
	}
	for rest := pairs; len(rest) > 0; rest = rest[2:] {
		flagName, target := rest[0], rest[1]
		if target == stdioPath {
			if err := claim(claims, flagName); err != nil {
				return err
			}
		}
	}
	return nil
}
// readInput returns the full contents named by path. A regular path is read
// from disk; stdioPath first claims stdin for flagName (failing if another
// flag already owns it) and then drains stdinReader.
func readInput(flagName, path string, claims *ioClaims) ([]byte, error) {
	if path != stdioPath {
		return os.ReadFile(path)
	}
	if err := claims.claimIn(flagName); err != nil {
		return nil, err
	}
	return io.ReadAll(stdinReader)
}
// openInput opens path for reading. A regular path is opened from disk;
// stdioPath first claims stdin for flagName (failing if another flag already
// owns it) and returns stdinReader wrapped so that Close does nothing.
func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) {
	if path != stdioPath {
		return os.Open(path)
	}
	if err := claims.claimIn(flagName); err != nil {
		return nil, err
	}
	return io.NopCloser(stdinReader), nil
}
// writeOutput stores data at path with permissions perm, unless path is
// stdioPath, in which case data goes to stdout and perm is ignored. Callers
// targeting stdioPath are expected to have claimed stdout beforehand via
// reserveOutputs.
func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error {
	if path != stdioPath {
		return os.WriteFile(path, data, perm)
	}
	_, err := stdout.Write(data)
	return err
}
// isStdio reports whether path is exactly the stdio sentinel. Callers use it
// to skip file-only checks such as "refuse to overwrite".
func isStdio(path string) bool {
	if path == stdioPath {
		return true
	}
	return false
}
+167
View File
@@ -0,0 +1,167 @@
package main
import (
"bytes"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// withStdin points stdinReader at r until the test (or subtest) t finishes,
// at which point the previous reader is put back.
func withStdin(t *testing.T, r io.Reader) {
	t.Helper()
	saved := stdinReader
	t.Cleanup(func() {
		stdinReader = saved
	})
	stdinReader = r
}
// Test_readInput_stdin checks that a "-" path makes readInput drain the
// stdin reader and record the flag as the stdin owner.
func Test_readInput_stdin(t *testing.T) {
	const payload = "hello"
	withStdin(t, bytes.NewReader([]byte(payload)))

	claims := &ioClaims{}
	data, err := readInput("path", "-", claims)
	require.NoError(t, err)
	assert.Equal(t, []byte(payload), data)
	assert.Equal(t, "path", claims.in)
}
// Test_readInput_file checks that a regular path is read from disk and
// leaves stdin unclaimed.
func Test_readInput_file(t *testing.T) {
	want := []byte("file")
	target := filepath.Join(t.TempDir(), "f")
	require.NoError(t, os.WriteFile(target, want, 0600))

	claims := new(ioClaims)
	data, err := readInput("path", target, claims)
	require.NoError(t, err)
	assert.Equal(t, want, data)
	assert.Empty(t, claims.in)
}
// Test_readInput_doubleStdinErrors checks that a second flag reading from
// stdin is refused once a different flag has already claimed it.
func Test_readInput_doubleStdinErrors(t *testing.T) {
	withStdin(t, bytes.NewBufferString("hello"))
	claims := &ioClaims{}

	_, firstErr := readInput("ca-key", "-", claims)
	require.NoError(t, firstErr)

	_, secondErr := readInput("ca-crt", "-", claims)
	require.EqualError(t, secondErr, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
}
// Test_openInput_stdin checks that a "-" path yields a reader that serves
// the stdin contents.
func Test_openInput_stdin(t *testing.T) {
	const payload = "hi"
	withStdin(t, bytes.NewBufferString(payload))

	claims := &ioClaims{}
	rc, err := openInput("ca", "-", claims)
	require.NoError(t, err)
	defer rc.Close()

	data, err := io.ReadAll(rc)
	require.NoError(t, err)
	assert.Equal(t, []byte(payload), data)
}
// Test_openInput_doubleStdinErrors checks that closing the first stdin
// reader does not release the claim: a different flag is still refused.
func Test_openInput_doubleStdinErrors(t *testing.T) {
	withStdin(t, bytes.NewBufferString("hi"))
	claims := &ioClaims{}

	first, err := openInput("ca", "-", claims)
	require.NoError(t, err)
	first.Close()

	_, conflict := openInput("crt", "-", claims)
	require.EqualError(t, conflict, `-ca and -crt both set to "-", only one input may read from stdin`)
}
func Test_writeOutput_stdout(t *testing.T) {
out := &bytes.Buffer{}
err := writeOutput("-", []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Equal(t, "payload", out.String())
}
func Test_writeOutput_file(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
out := &bytes.Buffer{}
err := writeOutput(p, []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Empty(t, out.String())
got, err := os.ReadFile(p)
require.NoError(t, err)
assert.Equal(t, []byte("payload"), got)
}
// Test_reserveOutputs_noConflict checks that only the flag whose path is "-"
// ends up owning stdout; file paths and empty paths are ignored.
func Test_reserveOutputs_noConflict(t *testing.T) {
	claims := &ioClaims{}
	err := reserveOutputs(claims,
		"out-key", "/tmp/key",
		"out-crt", "-",
		"out-qr", "",
	)
	require.NoError(t, err)
	assert.Equal(t, "out-crt", claims.out)
}
// Test_reserveOutputs_conflict checks that two outputs aimed at stdout are
// rejected with an error naming both flags.
func Test_reserveOutputs_conflict(t *testing.T) {
	claims := &ioClaims{}
	pairs := []string{
		"out-key", "-",
		"out-crt", "-",
	}
	require.EqualError(t,
		reserveOutputs(claims, pairs...),
		`-out-key and -out-crt both set to "-", only one output may write to stdout`)
}
func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
}()
var claims ioClaims
_ = reserveOutputs(&claims, "out-key")
}
// Test_reserveInputs_noConflict checks that only the flag whose path is "-"
// ends up owning stdin; file paths and empty paths are ignored.
func Test_reserveInputs_noConflict(t *testing.T) {
	claims := &ioClaims{}
	err := reserveInputs(claims,
		"ca-key", "/tmp/ca.key",
		"ca-crt", "-",
		"in-pub", "",
	)
	require.NoError(t, err)
	assert.Equal(t, "ca-crt", claims.in)
}
// Test_reserveInputs_conflict checks that two flags both set to "-" produce an
// error naming both of them.
func Test_reserveInputs_conflict(t *testing.T) {
	const want = `-ca-key and -ca-crt both set to "-", only one input may read from stdin`

	claims := ioClaims{}
	got := reserveInputs(&claims,
		"ca-key", "-",
		"ca-crt", "-",
	)

	require.EqualError(t, got, want)
}
// Test_claimIn_idempotent checks that claiming stdin twice for the same flag
// (a pre-claim followed by a lazy re-claim) succeeds both times and leaves the
// claim unchanged.
func Test_claimIn_idempotent(t *testing.T) {
	claims := ioClaims{}

	for attempt := 0; attempt < 2; attempt++ {
		require.NoError(t, claims.claimIn("ca-key"))
	}

	assert.Equal(t, "ca-key", claims.in)
}
// Test_claimOut_idempotent checks that claiming stdout twice for the same flag
// succeeds both times and leaves the claim unchanged.
func Test_claimOut_idempotent(t *testing.T) {
	claims := ioClaims{}

	for attempt := 0; attempt < 2; attempt++ {
		require.NoError(t, claims.claimOut("out-crt"))
	}

	assert.Equal(t, "out-crt", claims.out)
}
// Test_isStdio checks that only the exact string "-" is treated as the stdio
// marker; empty strings and paths that merely contain a dash are not.
func Test_isStdio(t *testing.T) {
	cases := []struct {
		path string
		want bool
	}{
		{"-", true},
		{"", false},
		{"./-", false},
		{"foo", false},
	}

	for _, tc := range cases {
		if tc.want {
			assert.True(t, isStdio(tc.path))
		} else {
			assert.False(t, isStdio(tc.path))
		}
	}
}
+13 -4
View File
@@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return err
}
caFile, err := os.Open(*vf.caPath)
var claims ioClaims
if err := reserveInputs(&claims,
"ca", *vf.caPath,
"crt", *vf.certPath,
); err != nil {
return err
}
caReader, err := openInput("ca", *vf.caPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca: %w", err)
}
defer caFile.Close()
defer caReader.Close()
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
if err != nil && !errors.Is(err, cert.ErrExpired) {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
rawCert, err := os.ReadFile(*vf.certPath)
rawCert, err := readInput("crt", *vf.certPath, &claims)
if err != nil {
return fmt.Errorf("unable to read crt: %w", err)
}
@@ -85,6 +93,7 @@ func verifySummary() string {
func verifyHelp(out io.Writer) {
vf := newVerifyFlags()
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
_, _ = out.Write([]byte(stdioHelpText))
vf.set.SetOutput(out)
vf.set.PrintDefaults()
}
+44
View File
@@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca string\n"+
" \tRequired: path to a file containing one or more ca certificates\n"+
" -crt string\n"+
@@ -122,3 +123,46 @@ func Test_verify(t *testing.T) {
assert.Empty(t, eb.String())
require.NoError(t, err)
}
// Test_verify_stdio exercises verify with "-" passed as a path flag: once with
// the certificate on stdin, once with the CA on stdin, and once with both flags
// set to "-", which must be rejected.
func Test_verify_stdio(t *testing.T) {
	ob := &bytes.Buffer{}
	eb := &bytes.Buffer{}

	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
	ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
	caPEM, _ := ca.MarshalPEM()

	crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
	crtPEM, _ := crt.MarshalPEM()

	caFile, err := os.CreateTemp("", "verify-ca")
	require.NoError(t, err)
	defer os.Remove(caFile.Name())
	// A short or failed write would otherwise surface later as a confusing
	// verify error, so fail here. Close the handle so it does not leak and so
	// the deferred Remove also works on platforms that refuse to delete open files.
	_, err = caFile.Write(caPEM)
	require.NoError(t, err)
	require.NoError(t, caFile.Close())

	// crt on stdin, ca on disk
	withStdin(t, bytes.NewReader(crtPEM))
	require.NoError(t, verify([]string{"-ca", caFile.Name(), "-crt", "-"}, ob, eb))
	assert.Empty(t, ob.String())
	assert.Empty(t, eb.String())

	// ca on stdin, crt on disk
	certFile, err := os.CreateTemp("", "verify-cert")
	require.NoError(t, err)
	defer os.Remove(certFile.Name())
	_, err = certFile.Write(crtPEM)
	require.NoError(t, err)
	require.NoError(t, certFile.Close())

	withStdin(t, bytes.NewReader(caPEM))
	ob.Reset()
	eb.Reset()
	require.NoError(t, verify([]string{"-ca", "-", "-crt", certFile.Name()}, ob, eb))
	assert.Empty(t, ob.String())
	assert.Empty(t, eb.String())

	// both flags on stdin should error
	withStdin(t, bytes.NewReader(caPEM))
	ob.Reset()
	eb.Reset()
	require.EqualError(t, verify([]string{"-ca", "-", "-crt", "-"}, ob, eb),
		`-ca and -crt both set to "-", only one input may read from stdin`)
}
+5 -2
View File
@@ -61,10 +61,13 @@ func main() {
}
if *configPath == "" {
fmt.Println("-config flag must be set")
flag.Usage()
p, err := config.DefaultPath()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
*configPath = p
}
c := config.NewC(l)
err := c.Load(*configPath)
+2 -15
View File
@@ -3,8 +3,6 @@ package main
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/kardianos/service"
"github.com/slackhq/nebula"
@@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error {
return nil
}
func fileExists(filename string) bool {
_, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return true
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
if *configPath == "" {
ex, err := os.Executable()
p, err := config.DefaultPath()
if err != nil {
return err
}
*configPath = filepath.Dir(ex) + "/config.yaml"
if !fileExists(*configPath) {
*configPath = filepath.Dir(ex) + "/config.yml"
}
*configPath = p
}
svcConfig := &service.Config{
+5 -2
View File
@@ -50,10 +50,13 @@ func main() {
}
if *configPath == "" {
fmt.Println("-config flag must be set")
flag.Usage()
p, err := config.DefaultPath()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
*configPath = p
}
l := logging.NewLogger(os.Stdout)
+29
View File
@@ -0,0 +1,29 @@
package config
import (
"fmt"
"os"
"path/filepath"
)
// DefaultPath locates a config file in the directory that holds the running
// executable. config.yaml wins over config.yml when both are present. When
// neither is present the returned error names both paths that were checked.
// An error from os.Executable is returned unchanged.
func DefaultPath() (string, error) {
	exe, err := os.Executable()
	if err == nil {
		return defaultPathInDir(filepath.Dir(exe))
	}
	return "", err
}
// defaultPathInDir returns the first of dir/config.yaml and dir/config.yml for
// which os.Stat succeeds. If neither can be stat'd it returns an empty path and
// an error that names both candidates.
func defaultPathInDir(dir string) (string, error) {
	// Listed in order of preference.
	candidates := [...]string{
		filepath.Join(dir, "config.yaml"),
		filepath.Join(dir, "config.yml"),
	}

	for _, candidate := range candidates {
		if _, statErr := os.Stat(candidate); statErr == nil {
			return candidate, nil
		}
	}

	return "", fmt.Errorf("no default config found at %s or %s", candidates[0], candidates[1])
}
+67
View File
@@ -0,0 +1,67 @@
package config
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDefaultPathInDir runs defaultPathInDir against temp directories holding
// each combination of config.yaml and config.yml, including neither.
func TestDefaultPathInDir(t *testing.T) {
	type seed struct{ name, body string }

	// Success cases. Seeds are written in order and the first seed is the
	// file defaultPathInDir is expected to return.
	hits := []struct {
		title string
		seeds []seed
	}{
		{"prefers config.yaml when both exist", []seed{{"config.yaml", "a: 1"}, {"config.yml", "a: 2"}}},
		{"returns config.yaml when only it exists", []seed{{"config.yaml", "a: 1"}}},
		{"falls back to config.yml when only it exists", []seed{{"config.yml", "a: 1"}}},
	}
	for _, tc := range hits {
		t.Run(tc.title, func(t *testing.T) {
			root := t.TempDir()
			for _, s := range tc.seeds {
				require.NoError(t, os.WriteFile(filepath.Join(root, s.name), []byte(s.body), 0644))
			}

			resolved, err := defaultPathInDir(root)
			require.NoError(t, err)
			assert.Equal(t, filepath.Join(root, tc.seeds[0].name), resolved)
		})
	}

	t.Run("errors when neither exists and names both paths", func(t *testing.T) {
		root := t.TempDir()

		resolved, err := defaultPathInDir(root)
		assert.Empty(t, resolved)
		require.Error(t, err)
		for _, name := range []string{"config.yaml", "config.yml"} {
			assert.Contains(t, err.Error(), filepath.Join(root, name))
		}
	})
}
func TestDefaultPath(t *testing.T) {
got, err := DefaultPath()
if err != nil {
ex, exErr := os.Executable()
require.NoError(t, exErr)
assert.Contains(t, err.Error(), filepath.Dir(ex))
return
}
ex, err := os.Executable()
require.NoError(t, err)
assert.Equal(t, filepath.Dir(ex), filepath.Dir(got))
assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(got))
}
+7 -37
View File
@@ -11,7 +11,6 @@ import (
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
@@ -45,8 +44,6 @@ type connectionManager struct {
inactivityTimeout atomic.Int64
dropInactive atomic.Bool
metricsTxPunchy metrics.Counter
l *slog.Logger
}
@@ -57,7 +54,6 @@ func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p
punchy: p,
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
}
cm.reload(c, true)
@@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if !outTraffic {
// Send a punch packet to keep the NAT state alive
cm.sendPunch(hostinfo)
cm.punchy.SendPunch(hostinfo)
}
return decision, hostinfo, primary
@@ -400,17 +396,16 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test the tunnel.
// Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo)
cm.punchy.SendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
return doNothing, nil, nil
}
if cm.punchy.GetTargetEverything() {
// This is similar to the old punchy behavior with a slight optimization.
// We aren't receiving traffic but we are sending it, punch on all known
// ips in case we need to re-prime NAT state
cm.sendPunch(hostinfo)
}
// We aren't receiving traffic but we are sending it. The outbound
// traffic itself refreshes the primary remote's NAT state; this
// fans out to non-primary remotes, but only if target_all_remotes
// is configured.
cm.punchy.SendPunchToAll(hostinfo)
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
@@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
}
}
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !cm.punchy.GetPunch() {
// Punching is disabled
return
}
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
// would lose the ability to notify us and punchy.respond would become unreliable.
return
}
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
+4 -4
View File
@@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Create manager
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
@@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Create manager
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
@@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
@@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
// Create manager
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
ifce.connectionManager = nc
+5 -4
View File
@@ -7,13 +7,14 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/noiseutil"
)
const ReplayWindow = 1024
type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
eKey noiseutil.CipherState
dKey noiseutil.CipherState
myCert cert.Certificate
peerCert *cert.CachedCertificate
initiator bool
@@ -31,8 +32,8 @@ func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
myCert: r.MyCert,
initiator: r.Initiator,
peerCert: r.RemoteCert,
eKey: NewNebulaCipherState(r.EKey),
dKey: NewNebulaCipherState(r.DKey),
eKey: noiseutil.NewCipherState(r.EKey, r.Cipher),
dKey: noiseutil.NewCipherState(r.DKey, r.Cipher),
window: NewBits(ReplayWindow),
}
ci.messageCounter.Add(r.MessageIndex)
+12 -60
View File
@@ -5,8 +5,6 @@ package nebula
import (
"net/netip"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
@@ -22,7 +20,9 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message
panic(err)
}
pipeTo.InjectUDPPacket(p)
if h.Type == msgType && h.Subtype == subType {
match := h.Type == msgType && h.Subtype == subType
p.Release()
if match {
return
}
}
@@ -38,7 +38,9 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
panic(err)
}
pipeTo.InjectUDPPacket(p)
if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType {
match := h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType
p.Release()
if match {
return
}
}
@@ -90,65 +92,15 @@ func (c *Control) GetTunTxChan() <-chan []byte {
return c.f.inside.(*overlay.TestTun).TxPackets
}
// InjectUDPPacket will inject a packet into the udp side of nebula
// InjectUDPPacket injects a packet into the udp side. We copy internally so the caller keeps ownership of p.
// The copy comes from the freelist so steady-state alloc is zero.
func (c *Control) InjectUDPPacket(p *udp.Packet) {
c.f.outside.(*udp.TesterConn).Send(p)
c.f.outside.(*udp.TesterConn).Send(p.Copy())
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
serialize := make([]gopacket.SerializableLayer, 0)
var netLayer gopacket.NetworkLayer
if toAddr.Is6() {
if !fromAddr.Is6() {
panic("Cant send ipv6 to ipv4")
}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} else {
if !fromAddr.Is4() {
panic("Cant send ipv4 to ipv6")
}
ip := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
}
udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
err := udp.SetNetworkLayerForChecksum(netLayer)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
serialize = append(serialize, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, serialize...)
if err != nil {
panic(err)
}
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
// InjectTunPacket pushes an IP packet onto the tun interface.
func (c *Control) InjectTunPacket(packet []byte) {
c.f.inside.(*overlay.TestTun).Send(packet)
}
func (c *Control) GetVpnAddrs() []netip.Addr {
+82 -16
View File
@@ -11,7 +11,6 @@ import (
"sync"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/slackhq/nebula/config"
)
@@ -23,7 +22,10 @@ type dnsServer struct {
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
pki *PKI
// selfHost is the cached FQDN we last seeded for ourselves
selfHost string
mux *dns.ServeMux
@@ -55,14 +57,14 @@ type dnsServer struct {
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
// watcher that tears the listener down on nebula shutdown. The returned
// pointer is always non-nil, even on error.
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) {
ds := &dnsServer{
l: l,
ctx: ctx,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
pki: pki,
}
ds.mux = dns.NewServeMux()
ds.mux.HandleFunc(".", ds.handleDnsRequest)
@@ -76,6 +78,7 @@ func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState,
if err := ds.reload(c, true); err != nil {
return ds, err
}
ds.seedSelf()
return ds, nil
}
@@ -113,7 +116,7 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
d.Stop()
}
// Drop any records that accumulated while enabled; a later re-enable
// will repopulate from fresh handshakes.
// will repopulate from fresh handshakes and a fresh seedSelf.
d.clearRecords()
return nil
}
@@ -121,17 +124,14 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
if running == nil {
// Was disabled (or never started); bring it up now.
go d.Start()
return nil
}
if sameAddr {
return nil
}
} else if !sameAddr {
d.shutdownServer(running, runningStarted, "reload")
// Old Start goroutine has now exited; bring up a fresh listener on the
// new address.
// Old Start goroutine has now exited; bring up a fresh listener on the new address.
go d.Start()
}
// Refresh the self entry every enabled reload so cert renewals that change our name or VPN addresses are picked up.
d.seedSelf()
return nil
}
@@ -249,6 +249,20 @@ func (d *dnsServer) QueryCert(data string) string {
return ""
}
// The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves.
// Answer self lookups straight from the local cert state.
if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) {
c := cs.GetDefaultCertificate()
if c == nil {
return ""
}
b, err := c.MarshalJSON()
if err != nil {
return ""
}
return string(b)
}
hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil {
return ""
@@ -266,12 +280,60 @@ func (d *dnsServer) QueryCert(data string) string {
return string(b)
}
// clearRecords drops all DNS records.
// clearRecords drops all DNS records, including the self entry.
func (d *dnsServer) clearRecords() {
d.Lock()
defer d.Unlock()
clear(d.dnsMap4)
clear(d.dnsMap6)
d.selfHost = ""
}
// seedSelf inserts (or refreshes) the DNS record for our own certificate name,
// pointing it at our VPN addresses. This lets a single-lighthouse network
// resolve the lighthouse's own hostname without the two-process workaround.
//
// It does nothing when DNS is disabled or no default certificate is available.
// Only the first IPv4 and the first IPv6 address in myVpnAddrs are recorded.
func (d *dnsServer) seedSelf() {
	if !d.enabled.Load() {
		return
	}

	state := d.certState()
	if state == nil {
		return
	}
	own := state.GetDefaultCertificate()
	if own == nil {
		return
	}
	fqdn := strings.ToLower(own.Name()) + "."

	d.Lock()
	defer d.Unlock()

	// If our name changed since the last seed, drop the records kept under the old name.
	if prev := d.selfHost; prev != "" && prev != fqdn {
		delete(d.dnsMap4, prev)
		delete(d.dnsMap6, prev)
	}
	d.selfHost = fqdn

	// Start from a clean slate so an address family we no longer have does not linger.
	delete(d.dnsMap4, fqdn)
	delete(d.dnsMap6, fqdn)

	var seen4, seen6 bool
	for _, addr := range state.myVpnAddrs {
		switch {
		case addr.Is4() && !seen4:
			d.dnsMap4[fqdn] = addr
			seen4 = true
		case addr.Is6() && !seen6:
			d.dnsMap6[fqdn] = addr
			seen6 = true
		}
		if seen4 && seen6 {
			break
		}
	}
}
// certState returns the current CertState from the attached PKI, or nil when
// no PKI has been set on this server.
func (d *dnsServer) certState() *CertState {
	if d.pki != nil {
		return d.pki.getCertState()
	}
	return nil
}
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
@@ -309,8 +371,12 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
return true
}
cs := d.certState()
if cs == nil || cs.myVpnAddrsTable == nil {
return false
}
//if we found it in this table, it's good
return d.myVpnAddrsTable.Contains(b)
return cs.myVpnAddrsTable.Contains(b)
}
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
+89
View File
@@ -9,7 +9,10 @@ import (
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -276,6 +279,92 @@ func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
}
}
// newTestPKI builds a minimal *PKI holding a single v2 cert whose name and
// VPN addresses are caller-provided, suitable for exercising seedSelf and
// QueryCert self handling.
//
// Each address becomes a host prefix (full bit length for its family). That
// prefix is used both as a network in the certificate and as an entry in the
// CertState's myVpnAddrsTable.
func newTestPKI(t *testing.T, name string, addrs []netip.Addr) *PKI {
	t.Helper()

	networks := make([]netip.Prefix, 0, len(addrs))
	addrsTable := new(bart.Lite)
	for _, a := range addrs {
		// BitLen is 32 for IPv4 and 128 for IPv6, so one host prefix serves both uses.
		host := netip.PrefixFrom(a, a.BitLen())
		networks = append(networks, host)
		addrsTable.Insert(host)
	}

	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
	c, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, nil, nil)

	cs := &CertState{
		v2Cert:            c,
		initiatingVersion: cert.Version2,
		myVpnAddrs:        addrs,
		myVpnAddrsTable:   addrsTable,
	}
	pki := &PKI{}
	pki.cs.Store(cs)
	return pki
}
// TestDnsServer_seedSelf_addsOwnRecord checks that after seedSelf the server
// answers A and AAAA queries for its own cert name with its own VPN addresses.
func TestDnsServer_seedSelf_addsOwnRecord(t *testing.T) {
	server, conf := newTestDnsServer(t)
	ownV4 := netip.MustParseAddr("10.0.0.1")
	ownV6 := netip.MustParseAddr("fd00::1")
	server.pki = newTestPKI(t, "lighthouse", []netip.Addr{ownV4, ownV6})

	// NOTE(review): the final setDnsConfig argument looks like the "enabled" toggle; confirm against the helper.
	setDnsConfig(conf, "127.0.0.1", "0", true, true)
	require.NoError(t, server.reload(conf, true))

	server.seedSelf()

	answer4, found4 := server.Query(dns.TypeA, "lighthouse.")
	assert.True(t, found4)
	assert.Equal(t, ownV4, answer4)

	answer6, found6 := server.Query(dns.TypeAAAA, "lighthouse.")
	assert.True(t, found6)
	assert.Equal(t, ownV6, answer6)
}
// TestDnsServer_seedSelf_disabled_noOp checks that seedSelf publishes no self
// record when the server was configured with the last setDnsConfig flag false.
func TestDnsServer_seedSelf_disabled_noOp(t *testing.T) {
	server, conf := newTestDnsServer(t)
	server.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")})

	// NOTE(review): the final setDnsConfig argument looks like the "enabled" toggle; confirm against the helper.
	setDnsConfig(conf, "127.0.0.1", "0", true, false)
	require.NoError(t, server.reload(conf, true))

	server.seedSelf()

	_, found := server.Query(dns.TypeA, "lighthouse.")
	assert.False(t, found)
}
// TestDnsServer_clearRecords_dropsSelfHost checks that clearRecords removes
// both the cached selfHost name and the seeded self record.
func TestDnsServer_clearRecords_dropsSelfHost(t *testing.T) {
	server, conf := newTestDnsServer(t)
	server.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
	setDnsConfig(conf, "127.0.0.1", "0", true, true)
	require.NoError(t, server.reload(conf, true))

	// Seed first so there is something for clearRecords to drop.
	server.seedSelf()
	require.NotEmpty(t, server.selfHost)

	server.clearRecords()

	assert.Empty(t, server.selfHost)
	_, found := server.Query(dns.TypeA, "lighthouse.")
	assert.False(t, found)
}
// TestDnsServer_QueryCert_returnsOwnCert checks that QueryCert returns a
// non-empty answer for our own VPN address and an empty one for an address
// that is neither ours nor a known peer.
func TestDnsServer_QueryCert_returnsOwnCert(t *testing.T) {
	server, _ := newTestDnsServer(t)
	ownV4 := netip.MustParseAddr("10.0.0.1")
	server.pki = newTestPKI(t, "lighthouse", []netip.Addr{ownV4})

	ownAnswer := server.QueryCert(ownV4.String() + ".")
	assert.NotEmpty(t, ownAnswer, "TXT lookup of our own VPN address should return our cert")

	stranger := netip.MustParseAddr("10.0.0.99")
	strangerAnswer := server.QueryCert(stranger.String() + ".")
	assert.Empty(t, strangerAnswer, "unknown peer IP should return nothing")
}
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
port := freeUDPPort(t)
ds, c := newTestDnsServer(t)
+12 -12
View File
@@ -47,7 +47,7 @@ func TestHandshakeRetransmitDuplicate(t *testing.T) {
defer r.RenderFlow()
t.Log("Trigger handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
t.Log("Grab my msg1")
msg1 := myControl.GetFromUDP(true)
@@ -97,7 +97,7 @@ func TestHandshakeTruncatedPacketRecovery(t *testing.T) {
defer r.RenderFlow()
t.Log("Trigger handshake")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
t.Log("Get msg1 and deliver to responder")
msg1 := myControl.GetFromUDP(true)
@@ -146,7 +146,7 @@ func TestHandshakeOrphanedMsg2Dropped(t *testing.T) {
defer r.RenderFlow()
t.Log("Complete a normal handshake")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
r.RouteForAllUntilTxTun(theirControl)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
@@ -248,7 +248,7 @@ func TestHandshakeLateResponse(t *testing.T) {
theirControl.Start()
t.Log("Trigger handshake from me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
t.Log("Grab msg1 but don't deliver")
msg1 := myControl.GetFromUDP(true)
@@ -292,7 +292,7 @@ func TestHandshakeSelfConnectionRejected(t *testing.T) {
myControl.Start()
t.Log("Trigger handshake from me")
myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
myControl.InjectTunPacket(BuildTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
msg1 := myControl.GetFromUDP(true)
t.Log("Drain any handshake retransmits before injecting")
@@ -375,7 +375,7 @@ func TestHandshakeRemoteAllowList(t *testing.T) {
defer r.RenderFlow()
t.Log("Trigger handshake from them")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")))
msg1 := theirControl.GetFromUDP(true)
t.Log("Rewrite the source to a blocked IP and inject")
@@ -426,7 +426,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
defer r.RenderFlow()
t.Log("Complete a normal handshake via the router")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
r.RouteForAllUntilTxTun(theirControl)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
@@ -437,7 +437,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
originalRemote := hi.CurrentRemote
t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")))
r.RouteForAllUntilTxTun(theirControl)
t.Log("Verify tunnel still works")
@@ -475,8 +475,8 @@ func TestHandshakeWrongResponderPacketStore(t *testing.T) {
evilControl.Start()
t.Log("Send multiple packets to them (cached during handshake)")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")))
t.Log("Route until evil tunnel is closed")
h := &header.H{}
@@ -540,7 +540,7 @@ func TestHandshakeRelayComplete(t *testing.T) {
theirControl.Start()
t.Log("Trigger handshake via relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -568,7 +568,7 @@ func TestHandshakeRelayComplete(t *testing.T) {
}
// NOTE: Relay V1 cert + IPv6 rejection is not tested here because
// InjectTunUDPPacket from a V4 node to a V6 address panics in the test
// BuildTunUDPPacket from a V4 node to a V6 address panics in the test
// framework. The check is in handshake_manager.go handleOutbound relay
// logic (lines ~304-313): if the relay host has a V1 cert and either
// address is IPv6, the relay is skipped.
+46 -30
View File
@@ -16,6 +16,7 @@ import (
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -39,11 +40,22 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
// Pre-build the IP packet bytes once so the bench measures the data plane,
// not gopacket SerializeLayers overhead.
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
// EnableFanIn switches the router to a 0-alloc routing path. Required
// for hot-path benchmarks; would conflict with GetFromUDP-using tests.
r.EnableFanIn()
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
myControl.InjectTunPacket(prebuilt)
// Release the TUN-side bytes back to the harness freelist; the bench
// just confirms a packet arrived, the contents aren't inspected.
overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
}
myControl.Stop()
@@ -71,11 +83,15 @@ func BenchmarkHotPathRelay(b *testing.B) {
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
r.EnableFanIn()
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
myControl.InjectTunPacket(prebuilt)
overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
}
myControl.Stop()
@@ -97,7 +113,7 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -191,7 +207,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -273,7 +289,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -352,8 +368,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true)
@@ -430,7 +446,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -441,7 +457,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")))
p = r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -480,7 +496,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -492,7 +508,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
@@ -535,7 +551,7 @@ func TestRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -565,7 +581,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -595,14 +611,14 @@ func TestReestablishRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Ensure packet traversal from them to me via the relay")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works")
@@ -617,7 +633,7 @@ func TestReestablishRelays(t *testing.T) {
for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
return router.RouteAndExit
@@ -634,7 +650,7 @@ func TestReestablishRelays(t *testing.T) {
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p = r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -669,7 +685,7 @@ func TestReestablishRelays(t *testing.T) {
t.Log("Assert the tunnel works the other way, too")
for {
t.Log("RouteForAllUntilTxTun")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works")
@@ -739,8 +755,8 @@ func TestStage1RaceRelays(t *testing.T) {
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl)
@@ -787,8 +803,8 @@ func TestStage1RaceRelays2(t *testing.T) {
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -852,7 +868,7 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -957,7 +973,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -1259,8 +1275,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true)
@@ -1476,7 +1492,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -1504,7 +1520,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)
+57 -2
View File
@@ -294,12 +294,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
controlB.InjectTunPacket(BuildTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")))
bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
controlA.InjectTunPacket(BuildTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")))
aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
@@ -408,3 +408,58 @@ func testLogLevelName() string {
}
return "info"
}
// BuildTunUDPPacket serializes data into a UDP datagram wrapped in an IPv4 or IPv6 header and returns the raw bytes,
// suitable for Control.InjectTunPacket. UDP is used because it is the simplest transport to assemble.
//
// The IP version follows toAddr: an IPv6 destination requires an IPv6 source, anything else requires an IPv4 source.
// An address family mismatch or a serialization failure panics.
func BuildTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) []byte {
	var (
		ipHeader gopacket.SerializableLayer
		network  gopacket.NetworkLayer
	)

	switch {
	case toAddr.Is6():
		if !fromAddr.Is6() {
			panic("Cant send ipv6 to ipv4")
		}
		v6 := &layers.IPv6{
			Version:    6,
			NextHeader: layers.IPProtocolUDP,
			SrcIP:      fromAddr.Unmap().AsSlice(),
			DstIP:      toAddr.Unmap().AsSlice(),
		}
		ipHeader, network = v6, v6

	default:
		if !fromAddr.Is4() {
			panic("Cant send ipv4 to ipv6")
		}
		v4 := &layers.IPv4{
			Version:  4,
			TTL:      64,
			Protocol: layers.IPProtocolUDP,
			SrcIP:    fromAddr.Unmap().AsSlice(),
			DstIP:    toAddr.Unmap().AsSlice(),
		}
		ipHeader, network = v4, v4
	}

	datagram := &layers.UDP{
		SrcPort: layers.UDPPort(fromPort),
		DstPort: layers.UDPPort(toPort),
	}
	// The UDP checksum covers a pseudo-header, so the UDP layer has to know which IP layer it rides on.
	if err := datagram.SetNetworkLayerForChecksum(network); err != nil {
		panic(err)
	}

	out := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}
	// Layers are serialized outermost first: IP header, UDP header, then the payload.
	if err := gopacket.SerializeLayers(out, opts, ipHeader, datagram, gopacket.Payload(data)); err != nil {
		panic(err)
	}
	return out.Bytes()
}
+3 -7
View File
@@ -18,14 +18,10 @@ import (
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
// before failing the assertion.
//
// IgnoreCurrent is necessary in the parallelized suite: other tests can
// leave goroutines mid-shutdown when this one runs (Stop is async, the
// wg.Wait() drain is not blocking on test return). We're checking that
// *this* test's setup tears down cleanly, not that the whole suite is
// idle at this moment. Intentionally NOT t.Parallel()'d for the same
// reason — concurrent test goroutines would always show up.
// Intentionally NOT t.Parallel()'d: concurrent tests would have their own
// goroutines running and trip the assertion.
func TestNoGoroutineLeaks(t *testing.T) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
defer goleak.VerifyNone(t)
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+188 -54
View File
@@ -13,6 +13,7 @@ import (
"regexp"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
@@ -24,6 +25,19 @@ import (
"golang.org/x/exp/maps"
)
// outNatKey is the (from, to) pair used by outNat. Comparable struct, so it works as a map key without the
// allocation cost of a string-concat key.
type outNatKey struct {
	// from and to are the two UDP endpoints of the flow being looked up or recorded.
	from, to netip.AddrPort
}
// fannedPacket pairs a UDP TX packet with its source control so the router can route it after popping from
// the fan-in channel.
type fannedPacket struct {
	// from is the control whose UDP TX channel produced pkt.
	from *nebula.Control
	// pkt is the packet taken off that channel.
	pkt *udp.Packet
}
type R struct {
// Simple map of the ip:port registered on a control to the control
// Basically a router, right?
@@ -34,12 +48,28 @@ type R struct {
// A last used map, if an inbound packet hit the inNat map then
// all return packets should use the same last used inbound address for the outbound sender
// map[outNatKey{from address, to address}] => ip:port to rewrite in the udp packet to receiver
outNat map[string]netip.AddrPort
outNat map[outNatKey]netip.AddrPort
// A map of vpn ip to the nebula control it belongs to
vpnControls map[netip.Addr]*nebula.Control
// Cached select infrastructure for RouteForAllUntilTxTun.
// The controls map is immutable after NewR so the cases are good for the test lifetime.
// We only rebuild if a different receiver is asked.
selRecvCtl *nebula.Control
selCases []reflect.SelectCase
selCtls []*nebula.Control
// Optional fan-in mode for hot-path benchmarks: one forwarder goroutine per control drains UDP TX into udpFanIn,
// so RouteForAllUntilTxTun can do a fixed 2-way native select instead of paying reflect.Select per call.
// Off by default (would otherwise interleave with tests that use GetFromUDP directly on the same control).
// Enabled by EnableFanIn.
udpFanIn chan fannedPacket
stopFanIn chan struct{}
fanInWG sync.WaitGroup
fanInMu sync.Mutex
fanInOn atomic.Bool
ignoreFlows []ignoreFlow
flow []flowEntry
@@ -119,7 +149,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
controls: make(map[netip.AddrPort]*nebula.Control),
vpnControls: make(map[netip.Addr]*nebula.Control),
inNat: make(map[netip.AddrPort]*nebula.Control),
outNat: make(map[string]netip.AddrPort),
outNat: make(map[outNatKey]netip.AddrPort),
flow: []flowEntry{},
ignoreFlows: []ignoreFlow{},
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -153,8 +183,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
case <-ctx.Done():
return
case <-clockSource.C:
r.Lock()
r.renderHostmaps("clock tick")
r.renderFlow()
r.Unlock()
}
}
}()
@@ -180,15 +212,21 @@ func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
// RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening.
func (r *R) RenderFlow() {
	// Stop the automatic renders first, then render once more while holding the router lock.
	r.cancelRender()
	r.Lock()
	defer r.Unlock()
	r.renderFlow()
}
// CancelFlowLogs stops flow logs from being tracked and discards whatever has been collected so far.
func (r *R) CancelFlowLogs() {
	r.cancelRender()

	// r.flow is shared with the routing paths, so it is only cleared while holding the router lock.
	r.Lock()
	defer r.Unlock()
	r.flow = nil
}
// renderFlow writes the flow log to disk. Caller must hold r.Lock. renderFlow reads r.flow / r.additionalGraphs and
// the *packet pointers stashed inside, all of which are mutated under the same lock by routing paths.
func (r *R) renderFlow() {
if r.flow == nil {
return
@@ -434,68 +472,157 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
panic("No control for udp tx " + a.String())
}
fp := r.unlockedInjectFlow(sender, c, p, false)
c.InjectUDPPacket(p)
c.InjectUDPPacket(p) // copies internally; original is ours to release
fp.WasReceived()
r.Unlock()
p.Release()
}
}
}
// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on receivers tun
// If the router doesn't have the nebula controller for that address, we panic
// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on the receiver's tun.
// If a control's UDP TX address can't be matched to a registered control, we panic.
//
// For allocation-sensitive callers (hot-path benchmarks, in particular relay
// benches with 3+ controls), call EnableFanIn() first.
func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
	// Fan-in is opt-in; without it we fall back to the reflect.Select based router.
	if !r.fanInOn.Load() {
		return r.routeReflect(receiver)
	}
	return r.routeFanIn(receiver)
}
// routeFanIn is the routing loop used while EnableFanIn is in effect. It forwards every UDP packet that arrives on
// the shared fan-in channel and returns the first packet that shows up on the receiver's TUN TX channel.
func (r *R) routeFanIn(receiver *nebula.Control) []byte {
	tunOut := receiver.GetTunTxChan()
	for {
		select {
		case routed := <-r.udpFanIn:
			// UDP traffic between controls: route it and keep waiting for the tun packet.
			r.routeUDP(routed.from, routed.pkt)

		case payload := <-tunOut:
			r.Lock()
			if r.flow != nil {
				// Flow logging is on: record a private copy so the log does not alias the returned buffer.
				logged := &udp.Packet{Data: make([]byte, len(payload))}
				copy(logged.Data, payload)
				r.unlockedInjectFlow(receiver, receiver, logged, true)
			}
			r.Unlock()
			return payload
		}
	}
}
// routeReflect is the default routing loop, built on reflect.Select. It pays the boxing allocation on every call but
// does not interfere with tests that pull packets straight off a control's UDP TX channel via GetFromUDP.
func (r *R) routeReflect(receiver *nebula.Control) []byte {
	cases, owners := r.selectCasesFor(receiver)
	for {
		idx, val, _ := reflect.Select(cases)
		if idx != 0 {
			// Any slot other than 0 is a control's UDP TX channel: route the packet and keep going.
			r.routeUDP(owners[idx], val.Interface().(*udp.Packet))
			continue
		}

		// Slot 0 is the receiver's TUN TX channel, which ends the loop.
		payload := val.Interface().([]byte)
		r.Lock()
		if r.flow != nil {
			// Flow logging is on: record a private copy so the log does not alias the returned buffer.
			logged := &udp.Packet{Data: make([]byte, len(payload))}
			copy(logged.Data, payload)
			r.unlockedInjectFlow(owners[0], owners[0], logged, true)
		}
		r.Unlock()
		return payload
	}
}
// EnableFanIn switches RouteForAllUntilTxTun to the alloc-free fan-in path.
// One forwarder goroutine per registered control drains UDP TX into a shared channel that RouteForAllUntilTxTun selects
// on alongside the receiver's TUN TX channel.
//
// Calling it while fan-in is already on is a no-op. The forwarders are stopped by a cleanup registered through
// r.t.Cleanup.
func (r *R) EnableFanIn() {
	r.fanInMu.Lock()
	defer r.fanInMu.Unlock()
	if r.fanInOn.Load() {
		return
	}
	// Fresh channels for this set of workers; the fan-in channel is buffered to 32 packets.
	r.udpFanIn = make(chan fannedPacket, 32)
	r.stopFanIn = make(chan struct{})
	for _, c := range r.controls {
		r.startFanInWorker(c)
	}
	// Flip the flag only after the channels exist and the workers are started, since RouteForAllUntilTxTun
	// reads it to decide which routing loop to use.
	r.fanInOn.Store(true)
	r.t.Cleanup(r.stopFanInWorkers)
}
// startFanInWorker spawns a goroutine that drains c's UDP TX channel into the shared fan-in channel, tagging each
// packet with c as its source. The goroutine exits once the stop channel is closed and is tracked by r.fanInWG.
//
// It is called by EnableFanIn after r.stopFanIn and r.udpFanIn have been created.
func (r *R) startFanInWorker(c *nebula.Control) {
	r.fanInWG.Add(1)
	udpTx := c.GetUDPTxChan()
	// Capture the channels once, on the caller's goroutine. The worker then never re-reads the struct fields,
	// which EnableFanIn reassigns under fanInMu (a lock the worker does not hold).
	stop := r.stopFanIn
	fanIn := r.udpFanIn
	go func() {
		defer r.fanInWG.Done()
		for {
			select {
			case <-stop:
				return
			case p := <-udpTx:
				// We own p now: either hand it to the router or release it if we are shutting down.
				select {
				case <-stop:
					p.Release()
					return
				case fanIn <- fannedPacket{from: c, pkt: p}:
				}
			}
		}
	}()
}
// stopFanInWorkers turns fan-in mode off and, if it was on, closes the stop channel and blocks until every
// forwarder goroutine has exited. Calling it while fan-in is off does nothing.
func (r *R) stopFanInWorkers() {
	r.fanInMu.Lock()
	running := r.fanInOn.Swap(false)
	r.fanInMu.Unlock()

	if running {
		close(r.stopFanIn)
		r.fanInWG.Wait()
	}
}
// routeUDP delivers a UDP TX packet taken from the control `from` to whichever control getControl resolves for
// p.To, records it in the flow log, and releases the source packet once it has been injected. It panics when no
// control is registered for the destination. The router lock is held for the whole call.
func (r *R) routeUDP(from *nebula.Control, p *udp.Packet) {
	r.Lock()
	defer r.Unlock()

	dest := r.getControl(from.GetUDPAddr(), p.To, p)
	if dest == nil {
		panic(fmt.Sprintf("No control for udp tx %s", p.To))
	}

	tracked := r.unlockedInjectFlow(from, dest, p, false)
	dest.InjectUDPPacket(p) // copies internally; original is ours to release
	tracked.WasReceived()
	p.Release()
}
// selectCasesFor returns the SelectCase array used by routeReflect: one slot for the receiver's TUN TX channel followed
// by one per control's UDP TX channel. Cached for the test lifetime, only rebuilt if the receiver changes.
func (r *R) selectCasesFor(receiver *nebula.Control) ([]reflect.SelectCase, []*nebula.Control) {
r.Lock()
defer r.Unlock()
if r.selRecvCtl == receiver && r.selCases != nil {
return r.selCases, r.selCtls
}
sc := make([]reflect.SelectCase, len(r.controls)+1)
cm := make([]*nebula.Control, len(r.controls)+1)
i := 0
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(receiver.GetTunTxChan()),
Send: reflect.Value{},
}
cm[i] = receiver
i++
sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan())}
cm[0] = receiver
i := 1
for _, c := range r.controls {
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
sc[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan())}
cm[i] = c
i++
}
for {
x, rx, _ := reflect.Select(sc)
r.Lock()
if x == 0 {
// we are the tun tx, we can exit
p := rx.Interface().([]byte)
np := udp.Packet{Data: make([]byte, len(p))}
copy(np.Data, p)
r.unlockedInjectFlow(cm[x], cm[x], &np, true)
r.Unlock()
return p
} else {
// we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet)
a := cm[x].GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil {
r.Unlock()
panic(fmt.Sprintf("No control for udp tx %s", p.To))
}
fp := r.unlockedInjectFlow(cm[x], c, p, false)
c.InjectUDPPacket(p)
fp.WasReceived()
}
r.Unlock()
}
r.selRecvCtl = receiver
r.selCases = sc
r.selCtls = cm
return sc, cm
}
// RouteExitFunc will call the whatDo func with each udp packet from sender.
@@ -522,6 +649,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
switch e {
case ExitNow:
r.Unlock()
p.Release()
return
case RouteAndExit:
@@ -529,6 +657,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
receiver.InjectUDPPacket(p)
fp.WasReceived()
r.Unlock()
p.Release()
return
case KeepRouting:
@@ -541,6 +670,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
}
r.Unlock()
p.Release()
}
}
@@ -641,6 +771,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
switch e {
case ExitNow:
r.Unlock()
p.Release()
return
case RouteAndExit:
@@ -648,6 +779,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
receiver.InjectUDPPacket(p)
fp.WasReceived()
r.Unlock()
p.Release()
return
case KeepRouting:
@@ -659,6 +791,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
}
r.Unlock()
p.Release()
}
}
@@ -702,19 +835,20 @@ func (r *R) FlushAll() {
}
receiver.InjectUDPPacket(p)
r.Unlock()
p.Release()
}
}
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
if newAddr, ok := r.outNat[outNatKey{from: fromAddr, to: toAddr}]; ok {
p.From = newAddr
}
c, ok := r.inNat[toAddr]
if ok {
r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
r.outNat[outNatKey{from: c.GetUDPAddr(), to: fromAddr}] = toAddr
return c
}
+125
View File
@@ -0,0 +1,125 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"net"
"strings"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
// TestSSHDLifecycle exercises the in-process sshd through several config reloads and a Control.Stop.
// It checks that the sshd listener is reachable after start and after each of three reloads, and that it is no
// longer reachable once the control has been stopped.
func TestSSHDLifecycle(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(
		cert.Version1, cert.Curve_CURVE25519,
		time.Now(), time.Now().Add(10*time.Minute),
		nil, nil, []string{},
	)

	// Fresh host and client keys plus a loopback address for the sshd to listen on.
	hostKeyPEM := generateSSHHostKey(t)
	clientSigner, clientAuthKey := generateSSHClientKey(t)
	sshdAddr := allocLoopbackPort(t)

	overrides := m{
		"sshd": m{
			"enabled":  true,
			"listen":   sshdAddr,
			"host_key": hostKeyPEM,
			"authorized_users": []m{{
				"user": "tester",
				"keys": []string{clientAuthKey},
			}},
		},
	}

	control, _, _, _ := newSimpleServer(cert.Version1, ca, caKey, "sshd-test", "10.222.0.1/24", overrides)
	control.Start()
	// Stops the control even when an assertion below aborts the test early.
	// NOTE(review): on the happy path Stop is also called explicitly further down, so it runs twice —
	// presumably Control.Stop tolerates that; confirm.
	t.Cleanup(func() { control.Stop() })

	// sshd binds in a goroutine after Start returns; wait for it.
	require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
		"sshd never started listening")

	// Each cycle issues "reload" over ssh and then waits for the listener to accept connections again.
	for i := 1; i <= 3; i++ {
		out := sshExecReload(t, sshdAddr, clientSigner)
		assert.Contains(t, out, "Reloading config", "reload cycle %d", i)
		require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
			"sshd not listening after reload cycle %d", i)
	}

	// After Stop the listener must go away within the polling window.
	control.Stop()
	require.Eventually(t, func() bool { return !canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
		"sshd still listening after Control.Stop")
}
// canDial reports whether a TCP connection to addr can be established within 100ms. A connection that does open is
// closed again immediately.
func canDial(addr string) bool {
	conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
	if err == nil {
		// Only reachability matters here, so the close error is deliberately ignored.
		_ = conn.Close()
	}
	return err == nil
}
// allocLoopbackPort grabs an unused TCP port on 127.0.0.1, closes it, and returns the address. There
// is a small race between releasing the port and the sshd reclaiming it; in practice the OS keeps the
// port available long enough for the test to bind it.
//
// The returned string is the listener's host:port address. Any listen or close error fails the test.
func allocLoopbackPort(t *testing.T) string {
	t.Helper()
	// Port 0 asks the OS to pick a free port.
	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := l.Addr().String()
	// Release the port right away so the caller can bind it.
	require.NoError(t, l.Close())
	return addr
}
// generateSSHHostKey creates a fresh ed25519 private key and returns it PEM-encoded, as produced by
// ssh.MarshalPrivateKey with the comment "nebula-e2e-host". Any failure fails the test.
func generateSSHHostKey(t *testing.T) string {
	t.Helper()

	_, hostKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	pemBlock, err := ssh.MarshalPrivateKey(hostKey, "nebula-e2e-host")
	require.NoError(t, err)

	encoded := pem.EncodeToMemory(pemBlock)
	return string(encoded)
}
// generateSSHClientKey creates a fresh ed25519 key pair for an ssh client. It returns the signer for the private key
// and the matching public key as a single authorized_keys line with surrounding whitespace trimmed. Any failure
// fails the test.
func generateSSHClientKey(t *testing.T) (ssh.Signer, string) {
	t.Helper()

	_, clientKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	signer, err := ssh.NewSignerFromKey(clientKey)
	require.NoError(t, err)

	authorizedLine := ssh.MarshalAuthorizedKey(signer.PublicKey())
	return signer, strings.TrimSpace(string(authorizedLine))
}
// sshExecReload connects to the sshd at addr as user "tester", authenticating with signer, runs the "reload" command
// in a new session and returns whatever the command wrote to stdout. Dial and session errors fail the test; the
// error from running the command itself is ignored on purpose (see below).
func sshExecReload(t *testing.T, addr string, signer ssh.Signer) string {
	t.Helper()
	cfg := &ssh.ClientConfig{
		User: "tester",
		Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
		// The server's host key is not verified by this helper.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         2 * time.Second,
	}
	client, err := ssh.Dial("tcp", addr, cfg)
	require.NoError(t, err)
	defer client.Close()

	sess, err := client.NewSession()
	require.NoError(t, err)
	defer sess.Close()

	// reload tears the channel down before sending exit-status, so Output returns an error on the
	// channel close. The output buffer still has whatever the reload callback wrote before that.
	out, _ := sess.Output("reload")
	return string(out)
}
+2 -2
View File
@@ -355,14 +355,14 @@ func TestCrossStackRelaysWork(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
+30
View File
@@ -138,6 +138,14 @@ listen:
# max, net.core.rmem_max and net.core.wmem_max
#read_buffer: 10485760
#write_buffer: 10485760
# On Windows only
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to UDP at the listener port.
# WFP sits below Windows Defender Firewall, so this lets peer handshakes reach Nebula's outside socket regardless
# of WDF's inbound rules.
# Default true; set to false to leave WDF in charge of inbound decisions on the listener port. Not reloadable.
#windows_bypass_wdf: true
# By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection
# in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running
# on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes.
@@ -163,17 +171,21 @@ listen:
punchy:
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
# This setting is reloadable.
punch: true
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
# Default is false
# This setting is reloadable.
#respond: true
# delays a punch response for misbehaving NATs, default is 1 second.
# This setting is reloadable.
#delay: 1s
# set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect.
# This setting is reloadable.
#respond_delay: 5s
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
@@ -282,6 +294,24 @@ tun:
# metric: 100
# install: true
# On Windows only, sets the network category of the nebula interface. Without this, Windows often
# leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more
# restrictive than you usually want for an overlay between trusted peers. Valid values:
# private - treat the nebula network as a private/trusted network (default)
# public - treat it as a public/untrusted network
# domain - treat it as a domain-authenticated network
# unset - leave whatever Windows decided alone
# Not reloadable.
#network_category: private
# On Windows only
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to the nebula adapter LUID.
# WFP sits below Windows Defender Firewall, so this lets inbound traffic through regardless of WDF rules.
# Filters are auto-removed when the adapter goes away.
# See listen.windows_bypass_wdf for the matching control over inbound to nebula's outside UDP listener.
# Default true; set to false to leave WDF in charge of inbound decisions on the nebula interface. Not reloadable.
#windows_bypass_wdf: true
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
# in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false
+34 -17
View File
@@ -59,7 +59,8 @@ type Firewall struct {
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool
// unsafeNetworks is the list of unsafe networks issued to us in the certificate
unsafeNetworks []netip.Prefix
rules string
rulesVersion uint16
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
assignedNetworks = append(assignedNetworks, network)
}
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
unsafeNetworks := c.UnsafeNetworks()
for _, n := range unsafeNetworks {
routableNetworks.Insert(n)
hasUnsafeNetworks = true
}
return &Firewall{
@@ -176,7 +176,7 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
DefaultTimeout: defaultTimeout,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
unsafeNetworks: unsafeNetworks,
l: l,
incomingMetrics: firewallMetrics{
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
@@ -1055,7 +1055,6 @@ func (r *rule) sanity() error {
}
func parsePort(s string) (int32, int32, error) {
var err error
const notAPort int32 = -2
if s == "any" {
return firewall.PortAny, firewall.PortAny, nil
@@ -1064,11 +1063,11 @@ func parsePort(s string) (int32, int32, error) {
return firewall.PortFragment, firewall.PortFragment, nil
}
if !strings.Contains(s, `-`) {
rPort, err := strconv.Atoi(s)
rPort, err := parsePortValue("", s)
if err != nil {
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
return notAPort, notAPort, err
}
return int32(rPort), int32(rPort), nil
return rPort, rPort, nil
}
sPorts := strings.SplitN(s, `-`, 2)
@@ -1079,22 +1078,40 @@ func parsePort(s string) (int32, int32, error) {
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
}
rStartPort, err := strconv.Atoi(sPorts[0])
startPort, err := parsePortValue("beginning range ", sPorts[0])
if err != nil {
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
return notAPort, notAPort, err
}
rEndPort, err := strconv.Atoi(sPorts[1])
endPort, err := parsePortValue("ending range ", sPorts[1])
if err != nil {
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
return notAPort, notAPort, err
}
startPort := int32(rStartPort)
endPort := int32(rEndPort)
if startPort == firewall.PortAny {
endPort = firewall.PortAny
}
return startPort, endPort, nil
}
// parsePortValue converts s, which must be a plain base-10 number in
// [0, 65535], into an int32 port value.
//
// The parse is done by strconv.ParseUint with a 16-bit size, so a sign
// prefix, any non-digit byte, or a value above 65535 is rejected before the
// widening conversion runs. The conversion therefore never truncates, and an
// oversized input can no longer wrap around onto firewall.PortAny (0) or
// firewall.PortFragment (-1).
//
// prefix is placed in front of the error text. parsePort passes "" for a
// single port and "beginning range " / "ending range " for range bounds,
// which keeps the error strings callers already match on.
func parsePortValue(prefix, s string) (int32, error) {
	port, err := strconv.ParseUint(s, 10, 16)
	switch {
	case err == nil:
		return int32(port), nil
	case errors.Is(err, strconv.ErrRange):
		// Syntactically a number, but it does not fit in 16 bits.
		return 0, fmt.Errorf("%sout of range [0,65535]; `%s`", prefix, s)
	default:
		return 0, fmt.Errorf("%swas not a number; `%s`", prefix, s)
	}
}
+69
View File
@@ -1029,6 +1029,75 @@ func Test_parsePort(t *testing.T) {
require.NoError(t, err)
}
// Test_parsePort_invalid covers inputs that must error. The named bug is
// that int32(strconv.Atoi("4294967296")) truncates to 0 == firewall.PortAny,
// silently turning a typo into a match-all-ports rule; the rest are
// representative syntax/range probes.
//
// Each case pairs an input with a substring that the returned error must
// contain, so the test also pins which validation path rejected the input.
func Test_parsePort_invalid(t *testing.T) {
	tests := []struct {
		name            string
		input           string
		wantErrContains string
	}{
		// Numeric overflow (the named bug + boundary).
		{"named bug: 2^32 truncates to PortAny", "4294967296", "out of range"},
		{"just above max real port", "65536", "out of range"},

		// Negatives route through the range branch and hit the empty-half
		// guard; included as defense in depth so a future refactor cannot
		// accidentally reach the int32 cast.
		{"negative", "-1", "could not be parsed"},

		// Syntax probes.
		{"NUL between digits", "4\x002", "was not a number"},
		{"hex notation", "0x10", "was not a number"},
		{"scientific notation", "1e3", "was not a number"},
		{"leading whitespace", " 42", "was not a number"},
		// U+FF14 U+FF12 (fullwidth "42"). Spelled with escapes on purpose:
		// if the literal is ever normalized to ASCII "42" it becomes a valid
		// port and this case can no longer fail as intended.
		{"fullwidth digits", "\uff14\uff12", "was not a number"},

		// Range branch.
		{"range upper out of range", "1-65536", "ending range out of range"},
		{"range lower out of range", "65536-65537", "beginning range out of range"},
		{"range with negative upper", "1--1", "ending range was not a number"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, _, err := parsePort(tc.input)
			require.Error(t, err, "input %q must error", tc.input)
			require.ErrorContains(t, err, tc.wantErrContains)
		})
	}
}
// Test_parsePort_valid_boundaries pins the accepted edge values (0, 1 and
// 65535, alone and as range bounds) so a later change to parsePort cannot
// quietly start rejecting or remapping them.
func Test_parsePort_valid_boundaries(t *testing.T) {
	type boundaryCase struct {
		name      string
		input     string
		wantStart int32
		wantEnd   int32
	}

	cases := []boundaryCase{
		{name: "zero is PortAny", input: "0", wantStart: 0, wantEnd: 0},
		{name: "min real port", input: "1", wantStart: 1, wantEnd: 1},
		{name: "max real port", input: "65535", wantStart: 65535, wantEnd: 65535},
		// A range starting at 0 is expected to come back as 0-0.
		{name: "range zero to max forces end to zero", input: "0-65535", wantStart: 0, wantEnd: 0},
		{name: "range max to max", input: "65535-65535", wantStart: 65535, wantEnd: 65535},
		{name: "range one to max", input: "1-65535", wantStart: 1, wantEnd: 65535},
		{name: "range with whitespace inside", input: " 1 - 2 ", wantStart: 1, wantEnd: 2},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			gotStart, gotEnd, err := parsePort(c.input)
			require.NoError(t, err)
			assert.Equal(t, c.wantStart, gotStart, "start port")
			assert.Equal(t, c.wantEnd, gotEnd, "end port")
		})
	}
}
func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger()
// Test a bad rule definition
+5 -5
View File
@@ -9,7 +9,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.26.0
github.com/gaissmai/bart v0.27.1
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
@@ -24,12 +24,12 @@ require (
github.com/vishvananda/netlink v1.3.1
go.uber.org/goleak v1.3.0
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.50.0
golang.org/x/crypto v0.51.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.52.0
golang.org/x/net v0.54.0
golang.org/x/sync v0.20.0
golang.org/x/sys v0.43.0
golang.org/x/term v0.42.0
golang.org/x/sys v0.44.0
golang.org/x/term v0.43.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.6.1
+10 -10
View File
@@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.27.1 h1:FysPzqETMJa8q9rNkLW5peT1hq25nLOz8ksHbSVoiAk=
github.com/gaissmai/bart v0.27.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+2
View File
@@ -31,6 +31,7 @@ type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
type Result struct {
EKey *noise.CipherState
DKey *noise.CipherState
Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in
MyCert cert.Certificate
RemoteCert *cert.CachedCertificate
RemoteIndex uint32
@@ -105,6 +106,7 @@ func NewMachine(
myVersion: version,
result: &Result{
Initiator: initiator,
Cipher: cred.cipherSuite,
},
}, nil
}
+2 -146
View File
@@ -23,7 +23,6 @@ const (
DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 10
DefaultHandshakeTriggerBuffer = 64
DefaultUseRelays = true
// maxCachedPackets is how many unsent packets we'll buffer per pending
// handshake before dropping further ones.
@@ -43,7 +42,6 @@ var (
tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries,
triggerBuffer: DefaultHandshakeTriggerBuffer,
useRelays: DefaultUseRelays,
}
)
@@ -51,7 +49,6 @@ type HandshakeConfig struct {
tryInterval time.Duration
retries int64
triggerBuffer int
useRelays bool
messageMetrics *MessageMetrics
}
@@ -86,6 +83,7 @@ type HandshakeHostInfo struct {
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
lastRelays []netip.Addr // Relays we attempted to use during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo
@@ -220,7 +218,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
fields := []any{
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
"initiatorIndex", hh.hostinfo.localIndexId,
"remoteIndex", hh.hostinfo.remoteIndexId,
"durationNs", time.Since(hh.startTime).Nanoseconds(),
}
// hh.machine can be nil here if buildStage0Packet never succeeded
@@ -326,146 +323,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
)
}
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to
if relay == vpnIp {
continue
}
// Don't relay to myself
if hm.f.myVpnAddrsTable.Contains(relay) {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
hm.f.Handshake(relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through
existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
if !ok {
// No relays exist or requested yet.
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
}
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", idx,
"relay", relay,
)
}
}
continue
}
switch existingRelay.State {
case Established:
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", existingRelay.LocalIndex,
"relay", relay,
)
}
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hostinfo.logger(hm.l).Error("Relay unexpected state",
"vpnIp", vpnIp,
"state", existingRelay.State,
"relay", relay,
)
}
}
}
hm.f.relayManager.StartRelays(hm.f, vpnIp, hh, stage0)
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered {
@@ -607,7 +465,6 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
@@ -630,7 +487,6 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
+14
View File
@@ -174,6 +174,10 @@ func (h *H) SubTypeName() string {
return SubTypeName(h.Type, h.Subtype)
}
// IsValidSubType reports whether this header's Type and Subtype form a pair
// accepted by the package-level IsValidSubType, to which it delegates.
func (h *H) IsValidSubType() bool {
return IsValidSubType(h.Type, h.Subtype)
}
// SubTypeName will transform a nebula message sub type into a human string
func SubTypeName(t MessageType, s MessageSubType) string {
if n, ok := subTypeMap[t]; ok {
@@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string {
return "unknown"
}
// IsValidSubType reports whether subTypeMap has an entry for message type t
// and that entry in turn contains sub type s. It returns false when either
// lookup misses.
func IsValidSubType(t MessageType, s MessageSubType) bool {
	subTypes, known := subTypeMap[t]
	if !known {
		return false
	}

	_, valid := (*subTypes)[s]
	return valid
}
// NewHeader turns bytes into a header
func NewHeader(b []byte) (*H, error) {
h := new(H)
-1
View File
@@ -391,7 +391,6 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
"error", err,
"udpAddr", remote,
"counter", c,
"attemptedCounter", c,
)
return
}
+27 -8
View File
@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
@@ -14,6 +15,7 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
}
func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
cs := f.pki.getCertState()
curCert := cs.getCertificate(cert.Version2)
if curCert == nil {
curCert = cs.getCertificate(cert.Version1)
}
// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
// Check to see if that set has changed, and if so, rebuild the firewall.
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)
if !c.HasChanged("firewall") && !certUnsafeChanged {
f.l.Debug("No firewall config change detected")
return
}
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
fw, err := NewFirewallFromConfig(f.l, cs, c)
if err != nil {
f.l.Error("Error while creating firewall during reload", "error", err)
return
@@ -491,11 +502,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
emit := func() {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
@@ -512,6 +519,18 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certMaxVersion.Update(int64(certState.v1Cert.Version()))
}
}
// Prime gauges so a Prometheus scrape that lands before the first tick
// sees real values instead of the zero defaults (issue #907).
emit()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
emit()
}
}
}
+73
View File
@@ -0,0 +1,73 @@
//go:build linux || darwin
package nebula
import (
"context"
"net/netip"
"testing"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_emitStats_primesGauges is the regression test for issue #907: a
// Prometheus scrape arriving before the first ticker fire used to see 0 for
// the certificate gauges. emitStats is expected to publish the gauges once
// before it starts waiting on the ticker, so the test checks that the TTL
// gauge reads zero beforehand and holds a real value after a single call.
func Test_emitStats_primesGauges(t *testing.T) {
	defer metrics.DefaultRegistry.UnregisterAll()

	logger := test.NewLogger()

	hm := newHostMap(logger)
	ranges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
	hm.preferredRanges.Store(&ranges)

	// A v1-only cert state whose certificate expires one hour from now.
	certState := &CertState{
		initiatingVersion: cert.Version1,
		privateKey:        []byte{},
		v1Cert:            &dummyCert{version: cert.Version1, notAfter: time.Now().Add(time.Hour)},
		v1Credential:      nil,
	}

	lightHouse := newTestLighthouse()
	f := &Interface{
		hostMap:          hm,
		inside:           &overlaytest.NoopTun{},
		outside:          &udp.NoopConn{},
		firewall:         &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
		lightHouse:       lightHouse,
		pki:              &PKI{},
		handshakeManager: NewHandshakeManager(logger, hm, lightHouse, &udp.NoopConn{}, defaultHandshakeConfig),
		l:                logger,
		// Original author's note: on linux, udp.NewUDPStatsEmitter indexes
		// writers[0] and asserts it to *udp.StdConn. A zero value is enough
		// because getMemInfo sees a nil rawConn, returns an error, and the
		// emitter degrades to a no-op.
		writers: []udp.Conn{&udp.StdConn{}},
	}
	f.pki.cs.Store(certState)

	gaugeValue := func(name string) int64 {
		return metrics.GetOrRegisterGauge(name, nil).Value()
	}

	ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
	require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs")

	// The context is cancelled up front, so emitStats primes the gauges and
	// returns without ever needing ticker.C. The one hour interval is only a
	// safety margin; the ticker is not expected to fire during the test.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	f.emitStats(ctx, time.Hour)

	primedTTL := ttlGauge.Value()
	assert.Positive(t, primedTTL, "ttl gauge should be primed by emitStats before its first tick")
	assert.LessOrEqual(t, primedTTL, int64(3600))
	assert.Equal(t, int64(cert.Version1), gaugeValue("certificate.initiating_version"))
	assert.Equal(t, int64(cert.Version1), gaugeValue("certificate.max_version"))
}
+120
View File
@@ -0,0 +1,120 @@
package nebula
import (
"net/netip"
"testing"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall
// rebuilds the firewall when only the certificate's UnsafeNetworks have changed,
// even if the firewall section of the YAML has not.
//
// Flow: build a firewall from cert c1, swap the stored cert state for c2 (one
// extra unsafe network), reload the identical YAML, then check that the
// Interface holds a new Firewall reflecting c2.
func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) {
	l := test.NewLogger()

	vpnNet := netip.MustParsePrefix("10.0.0.1/24")
	initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}

	// dummyCert avoids dragging the real signing pipeline into a unit test.
	c1 := &dummyCert{
		version:        cert.Version2,
		networks:       []netip.Prefix{vpnNet},
		unsafeNetworks: initialUnsafe,
	}

	pki := &PKI{}
	pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})

	// The rules must be nested under firewall.outbound / firewall.inbound.
	// With every key at column 0 the document would leave `firewall` empty
	// and repeat `proto`/`host` at the top level, so the indentation inside
	// this raw string is significant. YAML requires spaces here, not tabs.
	rawYAML := `firewall:
  outbound:
    - port: any
      proto: any
      host: any
  inbound:
    - port: any
      proto: any
      host: any
`
	cfg := config.NewC(l)
	require.NoError(t, cfg.LoadString(rawYAML))

	fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
	require.NoError(t, err)
	require.Equal(t, initialUnsafe, fw.unsafeNetworks)

	f := &Interface{
		pki:      pki,
		firewall: fw,
		l:        l,
	}

	// Swap the cert with a different UnsafeNetworks set.
	newUnsafe := []netip.Prefix{
		netip.MustParsePrefix("198.51.100.0/24"),
		netip.MustParsePrefix("203.0.113.0/24"),
	}
	c2 := &dummyCert{
		version:        cert.Version2,
		networks:       []netip.Prefix{vpnNet},
		unsafeNetworks: newUnsafe,
	}
	pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2})

	// Reload with the same YAML so HasChanged("firewall") reports false.
	require.NoError(t, cfg.ReloadConfigString(rawYAML))
	require.False(t, cfg.HasChanged("firewall"))

	f.reloadFirewall(cfg)

	assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced")
	assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks)
	assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5")))
}
// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when
// neither the firewall config nor the cert's UnsafeNetworks have changed.
//
// The same YAML is loaded and then reloaded against an unchanged cert state;
// the Interface must still point at the very same Firewall afterwards.
func TestReloadFirewall_NoChange(t *testing.T) {
	l := test.NewLogger()

	vpnNet := netip.MustParsePrefix("10.0.0.1/24")
	unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}

	c1 := &dummyCert{
		version:        cert.Version2,
		networks:       []netip.Prefix{vpnNet},
		unsafeNetworks: unsafe,
	}

	pki := &PKI{}
	pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})

	// The rules must be nested under firewall.outbound / firewall.inbound.
	// With every key at column 0 the document would leave `firewall` empty
	// and repeat `proto`/`host` at the top level, so the indentation inside
	// this raw string is significant. YAML requires spaces here, not tabs.
	rawYAML := `firewall:
  outbound:
    - port: any
      proto: any
      host: any
  inbound:
    - port: any
      proto: any
      host: any
`
	cfg := config.NewC(l)
	require.NoError(t, cfg.LoadString(rawYAML))

	fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
	require.NoError(t, err)

	f := &Interface{
		pki:      pki,
		firewall: fw,
		l:        l,
	}

	require.NoError(t, cfg.ReloadConfigString(rawYAML))
	f.reloadFirewall(cfg)

	assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
}
+25 -48
View File
@@ -15,7 +15,6 @@ import (
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
@@ -35,7 +34,6 @@ type LightHouse struct {
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Lite
punchConn udp.Conn
punchy *Punchy
// Local cache of answers from light houses
@@ -76,7 +74,6 @@ type LightHouse struct {
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *slog.Logger
}
@@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
myVpnNetworksTable: cs.myVpnNetworksTable,
addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
updateTrigger: make(chan struct{}, 1),
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
@@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
if c.GetBool("stats.lighthouse_metrics", false) {
h.metrics = newLighthouseMetrics()
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
h.metricHolepunchTx = metrics.NilCounter{}
}
err := h.reload(c, true)
@@ -279,16 +272,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
// Clean up. Entries still in the static_host_map will be re-built.
// Entries no longer present must have their (possible) background DNS goroutines stopped.
if existingStaticList := lh.staticList.Load(); existingStaticList != nil {
ourselves := lh.myVpnNetworks[0].Addr()
oldStaticList := lh.staticList.Load()
if oldStaticList != nil {
lh.RLock()
for staticVpnAddr := range *existingStaticList {
for staticVpnAddr := range *oldStaticList {
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
am.hr.Cancel()
am.ResetForOwner(ourselves)
}
}
lh.RUnlock()
}
// Build a new list based on current config.
staticList := make(map[netip.Addr]struct{})
err := lh.loadStaticMap(c, staticList)
@@ -296,6 +291,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return err
}
// For entries removed from static_host_map, stop the DNS goroutine and drop the cached addrs.
// All addrs must come from the lighthouses now that it's no longer a static host.
if oldStaticList != nil {
lh.RLock()
for staticVpnAddr := range *oldStaticList {
if _, stillStatic := staticList[staticVpnAddr]; stillStatic {
continue
}
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
am.ClearHostnameResults()
}
}
lh.RUnlock()
}
lh.staticList.Store(&staticList)
if !initial {
if c.HasChanged("static_host_map") {
@@ -1406,58 +1416,25 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
return
}
empty := []byte{0}
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
if !vpnPeer.IsValid() {
return
}
go func() {
time.Sleep(lhh.lh.punchy.GetDelay())
lhh.lh.metricHolepunchTx.Inc(1)
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}()
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Punching",
"vpnPeer", vpnPeer,
"logVpnAddr", logVpnAddr,
)
}
}
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts {
b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
}
}
for _, a := range n.Details.V6AddrPorts {
b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
}
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Sending a nebula test packet",
"vpnAddr", detailsVpnAddr,
)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
// a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled.
lhh.lh.punchy.ScheduleRespond(detailsVpnAddr)
}
func protoAddrToNetAddr(addr *Addr) netip.Addr {
+126
View File
@@ -303,6 +303,132 @@ func TestLighthouse_reload(t *testing.T) {
require.NoError(t, err)
}
// TestLighthouse_reloadStaticHostMap verifies that reloading static_host_map applies the new
// config rather than appending to it. See issue #718.
func TestLighthouse_reloadStaticHostMap(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
"10.128.0.2": []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err)
staticHost := netip.MustParseAddr("10.128.0.2")
otherHost := netip.MustParseAddr("10.128.0.3")
// Capture the RemoteList pointer up front; an in-flight handshake would hold the same one
// on hostinfo.remotes, so it must reflect every reload below.
pinned := lh.Query(staticHost)
require.NotNil(t, pinned)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, pinned.CopyAddrs([]netip.Prefix{}))
// Replace the remote address. The new address should be the only entry.
nc := map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"2.2.2.2:4242"},
},
}
rc, err := yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
rl := lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl, "RemoteList pointer must stay stable so in-flight handshakes pick up the change")
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Reload back to the original IP. Mirrors the round-trip in issue #718 step 6-8 where
// the buggy reload produced [1.1.1.1, 2.2.2.2, 1.1.1.1] instead of [1.1.1.1].
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"1.1.1.1:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Reload with the same config. An unchanged entry must not duplicate.
require.NoError(t, c.ReloadConfigString(string(rc)))
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Switch back to 2.2.2.2 so the rest of the test continues against a known address.
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"2.2.2.2:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
// Add a second host alongside the first. Both should be present, neither duplicated.
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"2.2.2.2:4242"},
"10.128.0.3": []any{"3.3.3.3:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl, "adding a sibling entry must not displace the existing RemoteList")
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{}))
rl = lh.Query(otherHost)
require.NotNil(t, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Drop the first host entirely. The vpnAddr is no longer marked static, our owner
// contribution is cleared, but the addrMap entry stays in place so non-static cache
// data (from lighthouse queries) on the same RemoteList isn't lost. In-flight handshakes
// that already had the pointer see an empty address list rather than retrying stale ones.
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.3": []any{"3.3.3.3:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
_, isStatic := lh.GetStaticHostList()[staticHost]
assert.False(t, isStatic)
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl)
assert.Empty(t, rl.CopyAddrs([]netip.Prefix{}))
rl = lh.Query(otherHost)
require.NotNil(t, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{}))
}
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
+5 -7
View File
@@ -55,7 +55,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
}
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
ssh, err := sshd.NewSSHServer(ctx, l.With("subsystem", "sshd"))
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
}
@@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
}
hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c, udpConns[0])
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
@@ -184,21 +184,17 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
messageMetrics = newMessageMetricsOnlyRecvError()
}
useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false)
handshakeConfig := HandshakeConfig{
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
useRelays: useRelays,
messageMetrics: messageMetrics,
}
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
ds, err := newDnsServerFromConfig(ctx, l, pki, hostMap, c)
if err != nil {
l.Warn("Failed to start DNS responder", "error", err)
}
@@ -244,6 +240,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
handshakeManager.f = ifce
go handshakeManager.Run(ctx)
punchy.Start(ctx, ifce, hostMap, lightHouse)
}
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
+8
View File
@@ -13,6 +13,8 @@ type MessageMetrics struct {
rxUnknown metrics.Counter
txUnknown metrics.Counter
rxInvalid metrics.Counter
}
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
@@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int
}
}
}
// RxInvalid adds i to the rxInvalid counter.
// It is a no-op when the receiver is nil or the counter has not been set up,
// so callers do not need to check whether metrics are enabled.
func (m *MessageMetrics) RxInvalid(i int64) {
	if m == nil || m.rxInvalid == nil {
		return
	}
	m.rxInvalid.Inc(i)
}
func newMessageMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
@@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics {
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil),
}
}
-88
View File
@@ -1,88 +0,0 @@
package nebula
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
type endianness interface {
PutUint64(b []byte, v uint64)
}
var noiseEndianness endianness = binary.BigEndian
type NebulaCipherState struct {
c cipher.AEAD
}
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
x := s.Cipher()
return &NebulaCipherState{c: x.(cipher.AEAD)}
}
type cipherAEADDanger interface {
EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error)
DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error)
}
// EncryptDanger encrypts and authenticates a given payload.
//
// out is a destination slice to hold the output of the EncryptDanger operation.
// - ad is additional data, which will be authenticated and appended to out, but not encrypted.
// - plaintext is encrypted, authenticated and appended to out.
// - n is a nonce value which must never be re-used with this key.
// - nb is a buffer used for temporary storage in the implementation of this call, which should
// be re-used by callers to minimize garbage collection.
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
switch ce := s.c.(type) {
case cipherAEADDanger:
return ce.EncryptDanger(out, ad, plaintext, n, nb)
default:
// TODO: Is this okay now that we have made messageCounter atomic?
// Alternative may be to split the counter space into ranges
//if n <= s.n {
// return nil, errors.New("CRITICAL: a duplicate counter value was used")
//}
//s.n = n
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
out = s.c.Seal(out, nb, plaintext, ad)
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
return out, nil
}
} else {
return nil, errors.New("no cipher state available to encrypt")
}
}
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
switch ce := s.c.(type) {
case cipherAEADDanger:
return ce.DecryptDanger(out, ad, ciphertext, n, nb)
default:
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
return s.c.Open(out, nb, ciphertext, ad)
}
} else {
return []byte{}, nil
}
}
func (s *NebulaCipherState) Overhead() int {
if s != nil {
return s.c.Overhead()
}
return 0
}
+53
View File
@@ -0,0 +1,53 @@
package noiseutil
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
// CipherStateAESGCM is the data-plane wrapper for the AES-GCM AEAD cipher.
// AES-GCM uses big-endian nonce encoding per the Noise spec.
type CipherStateAESGCM struct {
	// c is the AEAD taken from the post-handshake noise.CipherState; every
	// seal/open call on this type goes straight to it.
	c cipher.AEAD
}
// NewCipherStateAESGCM builds the AES-GCM data-plane cipher by pulling the underlying AEAD
// out of the post-handshake noise.CipherState.
// It is up to the caller to make sure s was really negotiated with AES-GCM: a different AEAD
// still passes the type assertion, but the nonce byte order on the wire would then be wrong.
func NewCipherStateAESGCM(s *noise.CipherState) *CipherStateAESGCM {
	aead := s.Cipher().(cipher.AEAD)
	return &CipherStateAESGCM{c: aead}
}
// EncryptDanger seals plaintext with counter n and appends the result to out.
// ad is handed to the AEAD as additional (authenticated, unencrypted) data.
// nb is scratch space for the nonce; bytes 0-11 are overwritten and the whole
// slice is passed to the AEAD as the nonce.
// A nil receiver returns an error instead of panicking.
func (s *CipherStateAESGCM) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return nil, errors.New("no cipher state available to encrypt")
	}

	// Nonce layout: four zero bytes, then the counter as 8 big-endian bytes.
	nb[0], nb[1], nb[2], nb[3] = 0, 0, 0, 0
	binary.BigEndian.PutUint64(nb[4:], n)

	sealed := s.c.Seal(out, nb, plaintext, ad)
	return sealed, nil
}
// DecryptDanger authenticates and opens ciphertext with counter n, appending the
// plaintext to out. It takes the same arguments as EncryptDanger and builds the
// nonce in nb the same way (four zero bytes, then n big-endian).
// A nil receiver returns an empty, non-nil slice and a nil error.
func (s *CipherStateAESGCM) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return []byte{}, nil
	}

	nb[0], nb[1], nb[2], nb[3] = 0, 0, 0, 0
	binary.BigEndian.PutUint64(nb[4:], n)

	opened, err := s.c.Open(out, nb, ciphertext, ad)
	return opened, err
}
// Overhead reports the wrapped AEAD's Overhead(), or 0 when the receiver is nil.
func (s *CipherStateAESGCM) Overhead() int {
	if s != nil {
		return s.c.Overhead()
	}
	return 0
}
+52
View File
@@ -0,0 +1,52 @@
package noiseutil
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
// CipherStateChaChaPoly is the data-plane wrapper for the ChaCha20-Poly1305 AEAD cipher.
// ChaCha20-Poly1305 uses little-endian nonce encoding per the Noise spec.
type CipherStateChaChaPoly struct {
	// c is the AEAD taken from the post-handshake noise.CipherState; every
	// seal/open call on this type goes straight to it.
	c cipher.AEAD
}
// NewCipherStateChaChaPoly builds the ChaCha20-Poly1305 data-plane cipher by pulling the
// underlying AEAD out of the post-handshake noise.CipherState.
// It is up to the caller to make sure s was really negotiated with ChaCha20-Poly1305.
func NewCipherStateChaChaPoly(s *noise.CipherState) *CipherStateChaChaPoly {
	aead := s.Cipher().(cipher.AEAD)
	return &CipherStateChaChaPoly{c: aead}
}
// EncryptDanger seals plaintext with counter n and appends the result to out.
// ad is handed to the AEAD as additional (authenticated, unencrypted) data.
// nb is scratch space for the nonce; bytes 0-11 are overwritten and the whole
// slice is passed to the AEAD as the nonce.
// A nil receiver returns an error instead of panicking.
func (s *CipherStateChaChaPoly) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return nil, errors.New("no cipher state available to encrypt")
	}

	// Nonce layout: four zero bytes, then the counter as 8 little-endian bytes.
	nb[0], nb[1], nb[2], nb[3] = 0, 0, 0, 0
	binary.LittleEndian.PutUint64(nb[4:], n)

	sealed := s.c.Seal(out, nb, plaintext, ad)
	return sealed, nil
}
// DecryptDanger authenticates and opens ciphertext with counter n, appending the
// plaintext to out. It takes the same arguments as EncryptDanger and builds the
// nonce in nb the same way (four zero bytes, then n little-endian).
// A nil receiver returns an empty, non-nil slice and a nil error.
func (s *CipherStateChaChaPoly) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
	if s == nil {
		return []byte{}, nil
	}

	nb[0], nb[1], nb[2], nb[3] = 0, 0, 0, 0
	binary.LittleEndian.PutUint64(nb[4:], n)

	opened, err := s.c.Open(out, nb, ciphertext, ad)
	return opened, err
}
// Overhead reports the wrapped AEAD's Overhead(), or 0 when the receiver is nil.
func (s *CipherStateChaChaPoly) Overhead() int {
	if s != nil {
		return s.c.Overhead()
	}
	return 0
}
+40
View File
@@ -0,0 +1,40 @@
package noiseutil
import (
"fmt"
"github.com/flynn/noise"
)
// CipherState is the post-handshake AEAD cipher used for the data plane.
// Each supported cipher has its own concrete implementation in this package with the nonce endianness hardcoded,
// so the encrypt/decrypt fast path avoids interface dispatch on the byte order.
type CipherState interface {
	// EncryptDanger encrypts and authenticates a given payload.
	//
	//   - out is the destination slice; the sealed output is appended to it and the
	//     resulting slice is returned.
	//   - ad is additional data, which is authenticated but not encrypted. The
	//     implementations in this package pass it to the AEAD as additional data only;
	//     they do not copy it into out.
	//   - plaintext is encrypted, authenticated and appended to out.
	//   - n is a nonce value which must never be re-used with this key.
	//   - nb is a scratch buffer used to assemble the nonce. The implementations in this
	//     package overwrite bytes 0-11 (four zero bytes followed by n as 8 bytes) and pass
	//     the whole slice to the AEAD as the nonce.
	EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error)
	// DecryptDanger authenticates and decrypts a given payload, with the same argument shape as EncryptDanger.
	// The implementations in this package return an empty slice and a nil error when the receiver is nil.
	DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error)
	// Overhead returns the AEAD tag size, or 0 if the receiver is nil.
	Overhead() int
}
// NewCipherState picks the concrete data-plane cipher wrapper for s based on the name
// reported by cipherFunc.
// cipherFunc must be the cipher that was used to build the noise CipherSuite which
// produced s. A cipher name that is neither AES-GCM nor ChaCha20-Poly1305 panics.
func NewCipherState(s *noise.CipherState, cipherFunc noise.CipherFunc) CipherState {
	name := cipherFunc.CipherName()
	if name == CipherAESGCM.CipherName() {
		return NewCipherStateAESGCM(s)
	}
	if name == noise.CipherChaChaPoly.CipherName() {
		return NewCipherStateChaChaPoly(s)
	}
	panic(fmt.Sprintf("noiseutil: unsupported cipher %q", cipherFunc.CipherName()))
}
+166
View File
@@ -0,0 +1,166 @@
package noiseutil
import (
"testing"
"github.com/flynn/noise"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCipherStateAESGCMRoundtrip runs the shared roundtrip checks against a pair of
// AES-GCM cipher states derived from one handshake.
func TestCipherStateAESGCMRoundtrip(t *testing.T) {
	initiator, responder := buildCipherStates(t, CipherAESGCM)
	sealer := NewCipherStateAESGCM(initiator)
	opener := NewCipherStateAESGCM(responder)
	roundtrip(t, sealer, opener)
}
// TestCipherStateChaChaPolyRoundtrip runs the shared roundtrip checks against a pair of
// ChaCha20-Poly1305 cipher states derived from one handshake.
func TestCipherStateChaChaPolyRoundtrip(t *testing.T) {
	initiator, responder := buildCipherStates(t, noise.CipherChaChaPoly)
	sealer := NewCipherStateChaChaPoly(initiator)
	opener := NewCipherStateChaChaPoly(responder)
	roundtrip(t, sealer, opener)
}
// TestNewCipherStateDispatch checks that NewCipherState returns the concrete wrapper
// type matching the cipher function it was given.
func TestNewCipherStateDispatch(t *testing.T) {
	aesState, _ := buildCipherStates(t, CipherAESGCM)
	chachaState, _ := buildCipherStates(t, noise.CipherChaChaPoly)

	gotAES := NewCipherState(aesState, CipherAESGCM)
	assert.IsType(t, &CipherStateAESGCM{}, gotAES)

	gotChaCha := NewCipherState(chachaState, noise.CipherChaChaPoly)
	assert.IsType(t, &CipherStateChaChaPoly{}, gotChaCha)
}
// TestNewCipherStateUnsupportedPanics checks that NewCipherState panics when handed a
// cipher function whose name it does not recognize.
func TestNewCipherStateUnsupportedPanics(t *testing.T) {
	enc, _ := buildCipherStates(t, CipherAESGCM)

	withUnknownCipher := func() {
		NewCipherState(enc, fakeCipher{})
	}
	assert.Panics(t, withUnknownCipher)
}
// fakeCipher is a noise.CipherFunc stand-in whose name matches no supported cipher.
type fakeCipher struct{}

// Cipher always returns nil; the tests only need the name, never a working cipher.
func (fakeCipher) Cipher(_ [32]byte) noise.Cipher {
	return nil
}

// CipherName reports a name that NewCipherState does not know.
func (fakeCipher) CipherName() string {
	return "Fake"
}
// buildCipherStates performs a complete NN handshake in memory using cipher c and
// returns two post-handshake CipherStates that share a key: the first comes from the
// initiator, the second from the responder. Any handshake error fails the test.
func buildCipherStates(t *testing.T, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
	t.Helper()

	suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256)

	initiator, err := noise.NewHandshakeState(noise.Config{
		CipherSuite: suite,
		Pattern:     noise.HandshakeNN,
		Initiator:   true,
	})
	require.NoError(t, err)

	responder, err := noise.NewHandshakeState(noise.Config{
		CipherSuite: suite,
		Pattern:     noise.HandshakeNN,
		Initiator:   false,
	})
	require.NoError(t, err)

	// First handshake message: initiator -> responder.
	first, _, _, err := initiator.WriteMessage(nil, nil)
	require.NoError(t, err)
	_, _, _, err = responder.ReadMessage(nil, first)
	require.NoError(t, err)

	// Second handshake message: responder -> initiator. Both sides now have cipher states.
	second, respCS, _, err := responder.WriteMessage(nil, nil)
	require.NoError(t, err)
	_, initCS, _, err := initiator.ReadMessage(nil, second)
	require.NoError(t, err)

	require.NotNil(t, initCS)
	require.NotNil(t, respCS)

	// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher; both values
	// returned here are cs1, one from each side.
	return initCS, respCS
}
// roundtrip encrypts a fixed message with enc, checks dec recovers it, checks that
// decrypting with a different counter fails, and checks both sides report a 16-byte overhead.
func roundtrip(t *testing.T, enc, dec CipherState) {
	t.Helper()

	const counter uint64 = 1
	var (
		message  = []byte("nebula cipher state roundtrip")
		aad      = []byte("aad")
		nonceBuf = make([]byte, 12)
	)

	sealed, err := enc.EncryptDanger(nil, aad, message, counter, nonceBuf)
	require.NoError(t, err)
	assert.NotEqual(t, message, sealed)

	opened, err := dec.DecryptDanger(nil, aad, sealed, counter, nonceBuf)
	require.NoError(t, err)
	assert.Equal(t, message, opened)

	// A different counter produces a different nonce, so authentication must fail.
	_, err = dec.DecryptDanger(nil, aad, sealed, counter+1, nonceBuf)
	require.Error(t, err)

	encOverhead, decOverhead := enc.Overhead(), dec.Overhead()
	assert.Equal(t, encOverhead, decOverhead)
	assert.Equal(t, 16, encOverhead)
}
// BenchmarkCipherStateEncryptAESGCM measures EncryptDanger through the CipherState
// interface for AES-GCM.
func BenchmarkCipherStateEncryptAESGCM(b *testing.B) {
	initiator, _ := buildCipherStatesB(b, CipherAESGCM)
	cs := NewCipherState(initiator, CipherAESGCM)
	benchEncryptCipherState(b, cs)
}
// BenchmarkCipherStateEncryptChaChaPoly measures EncryptDanger through the CipherState
// interface for ChaCha20-Poly1305.
func BenchmarkCipherStateEncryptChaChaPoly(b *testing.B) {
	initiator, _ := buildCipherStatesB(b, noise.CipherChaChaPoly)
	cs := NewCipherState(initiator, noise.CipherChaChaPoly)
	benchEncryptCipherState(b, cs)
}
// benchEncryptCipherState encrypts a 1280-byte payload b.N times with cs, reusing one
// output buffer and using a fresh counter (starting at 1) for every iteration.
func benchEncryptCipherState(b *testing.B, cs CipherState) {
	var (
		payload  = make([]byte, 1280)
		aad      = make([]byte, 16)
		nonceBuf = make([]byte, 12)
		// Sized up front so the loop appends into existing capacity.
		dst = make([]byte, 0, len(payload)+cs.Overhead())
	)

	b.ResetTimer()
	b.ReportAllocs()

	var counter uint64
	for i := 0; i < b.N; i++ {
		counter++
		sealed, err := cs.EncryptDanger(dst[:0], aad, payload, counter, nonceBuf)
		dst = sealed
		if err != nil {
			b.Fatal(err)
		}
	}
}
// buildCipherStatesB is the benchmark counterpart of buildCipherStates: it performs a
// complete NN handshake in memory using cipher c and returns the initiator's and the
// responder's first post-handshake CipherState. Any handshake error aborts the benchmark.
func buildCipherStatesB(b *testing.B, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
	b.Helper()

	suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256)

	initiator, err := noise.NewHandshakeState(noise.Config{
		CipherSuite: suite,
		Pattern:     noise.HandshakeNN,
		Initiator:   true,
	})
	if err != nil {
		b.Fatal(err)
	}

	responder, err := noise.NewHandshakeState(noise.Config{
		CipherSuite: suite,
		Pattern:     noise.HandshakeNN,
		Initiator:   false,
	})
	if err != nil {
		b.Fatal(err)
	}

	// First handshake message: initiator -> responder.
	first, _, _, err := initiator.WriteMessage(nil, nil)
	if err != nil {
		b.Fatal(err)
	}
	_, _, _, err = responder.ReadMessage(nil, first)
	if err != nil {
		b.Fatal(err)
	}

	// Second handshake message: responder -> initiator.
	second, respCS, _, err := responder.WriteMessage(nil, nil)
	if err != nil {
		b.Fatal(err)
	}
	_, initCS, _, err := initiator.ReadMessage(nil, second)
	if err != nil {
		b.Fatal(err)
	}

	return initCS, respCS
}
func TestCipherStateNilSafety(t *testing.T) {
var aes *CipherStateAESGCM
_, err := aes.EncryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.Error(t, err)
out, err := aes.DecryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.NoError(t, err)
assert.Empty(t, out)
assert.Equal(t, 0, aes.Overhead())
var cc *CipherStateChaChaPoly
_, err = cc.EncryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.Error(t, err)
out, err = cc.DecryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.NoError(t, err)
assert.Empty(t, out)
assert.Equal(t, 0, cc.Overhead())
}
+126 -164
View File
@@ -20,23 +20,46 @@ const (
minFwPacketLen = 4
)
var ErrOutOfWindow = errors.New("out of window packet")
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
// TODO: record metrics for rx holepunch/punchy packets?
if len(packet) > 1 {
f.l.Info("Error while parsing inbound packet",
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while parsing inbound packet",
"from", via,
"error", err,
"packet", packet,
)
}
}
return
}
if h.Version != header.Version {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected header version received", "from", via)
}
return
}
// Check before processing to see if this is a expected type/subtype
if !h.IsValidSubType() {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected packet received", "from", via)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Refusing to process double encrypted packet", "from", via)
}
@@ -44,31 +67,108 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
}
}
// don't keep Rx metrics for message type, since you can see those in the tun metrics
if h.Type != header.Message {
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
}
// Unencrypted packets
switch h.Type {
case header.Handshake:
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.handleRecvError(via.UdpAddr, h)
return
}
// Relay packets are special
isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay)
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
if isMessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
// At this point we should have a valid existing tunnel, verify and send
// recvError if necessary
if hostinfo == nil || hostinfo.ConnectionState == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return
}
// All remaining packets are encrypted
ci := hostinfo.ConnectionState
if !ci.window.Check(f.l, h.MessageCounter) {
return
}
// Relay packets are special
if isMessageRelay {
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache)
return
}
out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Failed to decrypt packet",
"error", err,
"from", via,
"header", h,
)
}
return
}
// Roam before we respond
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, via, h) {
switch h.Subtype {
case header.MessageNone:
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
return
}
case header.LightHouse:
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f)
case header.Test:
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
case header.TestReply:
// No-op, useful for the Roaming and connectionManager side-effects above
case header.TestRequest:
f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h)
return
}
case header.MessageRelay:
case header.CloseTunnel:
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
case header.Control:
f.relayManager.HandleControlMsg(hostinfo, out, f)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h)
}
}
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
@@ -76,6 +176,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// which will gracefully fail in the DecryptDanger call.
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
@@ -93,8 +194,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
"vpnAddrs", hostinfo.vpnAddrs,
"remoteIndex", h.RemoteIndex,
"relayRemoteIndex", h.RemoteIndex,
)
return
}
@@ -111,15 +211,14 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
"relayTo", relay.PeerAddr,
"relayFrom", hostinfo.vpnAddrs[0],
"error", err,
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
)
return
}
@@ -131,9 +230,14 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
return
default:
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type)
}
return
}
} else {
hostinfo.logger(f.l).Info("Unexpected target relay state",
@@ -143,116 +247,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
)
return
}
}
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(via.UdpAddr, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type)
}
return
}
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -300,23 +299,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
}
var (
ErrPacketTooShort = errors.New("packet is too short")
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
@@ -523,38 +505,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
}
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
}
return nil, errors.New("out of window packet")
return nil, ErrOutOfWindow
}
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
return false
}
err = newPacket(out, true, fwPacket)
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err,
"packet", out,
)
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
}
return false
return
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
@@ -568,15 +532,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
"reason", dropReason,
)
}
return false
return
}
f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.Error("Failed to write to tun", "error", err)
}
return true
}
func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
+358
View File
@@ -0,0 +1,358 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"errors"
"fmt"
"log/slog"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
// networkCategory is the Go counterpart of the NLM_NETWORK_CATEGORY enum
// declared in netlistmgr.h.
type networkCategory int32

// The numeric values are written out explicitly because they mirror the
// Windows enum rather than being an arbitrary Go-side ordering.
const (
	networkCategoryPublic              networkCategory = 0
	networkCategoryPrivate             networkCategory = 1
	networkCategoryDomainAuthenticated networkCategory = 2
)

// String returns the short lowercase name of the category ("public",
// "private" or "domain"), or "unknown(N)" for any other numeric value.
func (c networkCategory) String() string {
	if c == networkCategoryPublic {
		return "public"
	}
	if c == networkCategoryPrivate {
		return "private"
	}
	if c == networkCategoryDomainAuthenticated {
		return "domain"
	}
	return fmt.Sprintf("unknown(%d)", c)
}

// parseNetworkCategory interprets the user-supplied tun.network_category
// setting. Matching ignores case and surrounding whitespace.
//
// The boolean result reports whether a category should be applied at all: the
// empty string and "unset" yield (0, false, nil), meaning "leave the category
// alone". Any unrecognized value yields an error that quotes the input exactly
// as it was given.
func parseNetworkCategory(s string) (networkCategory, bool, error) {
	normalized := strings.ToLower(strings.TrimSpace(s))
	if normalized == "" || normalized == "unset" {
		return 0, false, nil
	}

	var cat networkCategory
	switch normalized {
	case "public":
		cat = networkCategoryPublic
	case "private":
		cat = networkCategoryPrivate
	case "domain", "domainauthenticated":
		cat = networkCategoryDomainAuthenticated
	default:
		return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s)
	}
	return cat, true, nil
}
// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B}
var clsidNetworkListManager = windows.GUID{
Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B}
var iidINetworkListManager = windows.GUID{
Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves.
// The proc is resolved lazily from ole32.dll.
var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance")
// clsCtxAll ORs together the in-process server, in-process handler, local
// server and remote server class contexts. It is the class context passed to
// CoCreateInstance by createNetworkListManager.
const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER |
windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER
// Raw result codes that coInit special-cases when CoInitializeEx returns an error.
const (
// hrSFALSE is treated by coInit as a successful init that must still be
// balanced with CoUninitialize.
hrSFALSE = 0x00000001
// hrRPCEChangedMode (RPC_E_CHANGED_MODE) is treated by coInit as usable COM
// that must NOT be balanced with CoUninitialize.
hrRPCEChangedMode = 0x80010106
)
// hresult holds a raw 32-bit COM result code as returned by the vtable calls
// in this file.
type hresult uint32

// failed reports whether the code's top bit is set, which is the same as the
// value being negative when reinterpreted as a signed 32-bit integer.
func (h hresult) failed() bool {
	const severityBit = 1 << 31
	return h&severityBit != 0
}

// String formats the code as "HRESULT 0x" followed by eight zero-padded
// lowercase hex digits.
func (h hresult) String() string {
	raw := uint32(h)
	return fmt.Sprintf("HRESULT 0x%08x", raw)
}
var errAdapterNotFound = errors.New("adapter not present in network connections enumeration")
// Vtable layouts. Slot order must match the declaration order in netlistmgr.h.
// All NLM interfaces here derive from IDispatch, which derives from IUnknown.
//
// iUnknownVtbl holds the three IUnknown slots. Each field is a raw function
// address that is handed to syscall.SyscallN by the wrappers in this file; of
// the three, only Release is actually invoked here.
type iUnknownVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
}
// iDispatchVtbl is the IUnknown slots followed by the four IDispatch slots.
// None of the IDispatch slots are called in this file; they are declared so
// that the interface-specific slots in the embedding vtables land at the
// correct offsets.
type iDispatchVtbl struct {
iUnknownVtbl
GetTypeInfoCount uintptr
GetTypeInfo uintptr
GetIDsOfNames uintptr
Invoke uintptr
}
// iNetworkListManagerVtbl is the INetworkListManager vtable: the IDispatch
// slots followed by the interface's own methods. Only Release and
// GetNetworkConnections are invoked in this file; the other fields are
// placeholders that keep the slot offsets correct.
type iNetworkListManagerVtbl struct {
iDispatchVtbl
GetNetworks uintptr
GetNetwork uintptr
GetNetworkConnections uintptr
GetNetworkConnection uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
}
// iNetworkListManager overlays a COM INetworkListManager object. Its single
// field is the vtable pointer, and the address of the struct itself is what
// gets passed as the first ("this") argument of every call.
type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl }
// Release invokes the object's Release slot. The call's return values are discarded.
func (n *iNetworkListManager) Release() {
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}
// GetNetworkConnections invokes the GetNetworkConnections slot and returns the
// enumerator the call writes into its out-parameter. A failing HRESULT is
// reported as an error and no enumerator is returned. Callers in this package
// Release the enumerator when they are done with it.
func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) {
var enum *iEnumNetworkConnections
// The unsafe.Pointer -> uintptr conversions are written directly in the
// argument list of the syscall; do not hoist them into variables.
r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections,
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr)
}
return enum, nil
}
// iEnumNetworkConnectionsVtbl is the IEnumNetworkConnections vtable: the
// IDispatch slots followed by the enumerator's own methods. Only Release and
// Next are invoked in this file; the other fields keep the slot offsets correct.
type iEnumNetworkConnectionsVtbl struct {
iDispatchVtbl
NewEnum uintptr
Next uintptr
Skip uintptr
Reset uintptr
Clone uintptr
}
// iEnumNetworkConnections overlays a COM IEnumNetworkConnections object; its
// single field is the vtable pointer.
type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl }
// Release invokes the enumerator's Release slot. The call's return values are discarded.
func (e *iEnumNetworkConnections) Release() {
syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e)))
}
// Next returns the next connection, or (nil, nil) at the end of the enumeration.
// It requests exactly one element per call. A failing HRESULT is returned as an
// error; otherwise "end of enumeration" is detected by the fetched count being
// zero rather than by inspecting the (non-failing) HRESULT value.
func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) {
var conn *iNetworkConnection
var fetched uint32
// Arguments after the receiver: element count (1), out-pointer for the
// element, out-pointer for the number of elements actually fetched.
r1, _, _ := syscall.SyscallN(e.Vtbl.Next,
uintptr(unsafe.Pointer(e)), 1,
uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr)
}
if fetched == 0 {
return nil, nil
}
return conn, nil
}
// iNetworkConnectionVtbl is the INetworkConnection vtable: the IDispatch slots
// followed by the interface's own methods. Only Release, GetAdapterId and
// GetNetwork are invoked in this file; the other fields keep the slot offsets
// correct.
type iNetworkConnectionVtbl struct {
iDispatchVtbl
GetNetwork uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
GetConnectionId uintptr
GetAdapterId uintptr
GetDomainType uintptr
}
// iNetworkConnection overlays a COM INetworkConnection object; its single
// field is the vtable pointer.
type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl }
// Release invokes the connection's Release slot. The call's return values are discarded.
func (c *iNetworkConnection) Release() {
syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c)))
}
// GetAdapterId invokes the GetAdapterId slot, which fills in the GUID passed
// by pointer. On a failing HRESULT it returns the zero GUID and an error.
func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) {
var g windows.GUID
r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId,
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)),
)
if hr := hresult(r1); hr.failed() {
return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr)
}
return g, nil
}
// GetNetwork invokes the GetNetwork slot and returns the network object the
// call writes into its out-parameter. A failing HRESULT is reported as an
// error and no network is returned. Callers in this package Release the
// returned network when they are done with it.
func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) {
var net *iNetwork
r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork,
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr)
}
return net, nil
}
// iNetworkVtbl is the INetwork vtable: the IDispatch slots followed by the
// interface's own methods. Only Release, GetCategory and SetCategory have
// wrappers in this file; the other fields keep the slot offsets correct.
type iNetworkVtbl struct {
iDispatchVtbl
GetName uintptr
SetName uintptr
GetDescription uintptr
SetDescription uintptr
GetNetworkId uintptr
GetDomainType uintptr
GetNetworkConnections uintptr
GetTimeCreatedAndConnected uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
GetCategory uintptr
SetCategory uintptr
}
// iNetwork overlays a COM INetwork object; its single field is the vtable pointer.
type iNetwork struct{ Vtbl *iNetworkVtbl }
// Release invokes the network's Release slot. The call's return values are discarded.
func (n *iNetwork) Release() {
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}
// GetCategory invokes the GetCategory slot, which writes the current category
// through the pointer it is given. On a failing HRESULT it returns 0 and an error.
func (n *iNetwork) GetCategory() (networkCategory, error) {
var c networkCategory
r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory,
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)),
)
if hr := hresult(r1); hr.failed() {
return 0, fmt.Errorf("INetwork.GetCategory: %s", hr)
}
return c, nil
}
// SetCategory invokes the SetCategory slot with c passed by value (as its
// int32 representation widened to uintptr). A failing HRESULT is returned as
// an error.
func (n *iNetwork) SetCategory(c networkCategory) error {
r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory,
uintptr(unsafe.Pointer(n)), uintptr(int32(c)),
)
if hr := hresult(r1); hr.failed() {
return fmt.Errorf("INetwork.SetCategory: %s", hr)
}
return nil
}
// coInit brings up COM (multithreaded mode) on the calling OS thread and
// returns the function the caller must defer to undo it.
//
// Three outcomes count as success:
//   - a nil error from CoInitializeEx: the cleanup is CoUninitialize;
//   - hrSFALSE: the cleanup is likewise CoUninitialize;
//   - hrRPCEChangedMode: COM is already initialized on this thread in a
//     different mode, which is fine for our calls, but we did not add an
//     init to balance, so the cleanup is a no-op.
//
// Anything else is returned as a wrapped error with a nil cleanup function.
func coInit() (func(), error) {
	err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
	if err == nil {
		return windows.CoUninitialize, nil
	}

	errno, isErrno := err.(syscall.Errno)
	if isErrno && uint32(errno) == hrSFALSE {
		return windows.CoUninitialize, nil
	}
	if isErrno && uint32(errno) == hrRPCEChangedMode {
		return func() {}, nil
	}
	return nil, fmt.Errorf("CoInitializeEx: %w", err)
}
// createNetworkListManager calls CoCreateInstance for the NetworkListManager
// class, requesting the INetworkListManager interface with the clsCtxAll class
// context, and returns the resulting object. A failing HRESULT is returned as
// an error. Callers in this package call coInit first and Release the returned
// object when done.
func createNetworkListManager() (*iNetworkListManager, error) {
var nlm *iNetworkListManager
// Arguments: class id, outer unknown (0 = none), class context,
// interface id, out-pointer receiving the interface.
r1, _, _ := procCoCreateInstance.Call(
uintptr(unsafe.Pointer(&clsidNetworkListManager)),
0,
uintptr(clsCtxAll),
uintptr(unsafe.Pointer(&iidINetworkListManager)),
uintptr(unsafe.Pointer(&nlm)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr)
}
return nlm, nil
}
// setNetworkCategory walks the Network List Manager's connection enumeration
// looking for the connection whose adapter id equals adapterGUID, then sets
// the category on that connection's parent network.
//
// It returns errAdapterNotFound when the enumeration ends without a match
// (the adapter is not yet visible to NLM). A connection whose adapter id
// cannot be read is skipped rather than treated as fatal. Every COM object
// obtained along the way is released before returning.
func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error {
	deinit, err := coInit()
	if err != nil {
		return err
	}
	defer deinit()

	nlm, err := createNetworkListManager()
	if err != nil {
		return err
	}
	defer nlm.Release()

	enum, err := nlm.GetNetworkConnections()
	if err != nil {
		return err
	}
	defer enum.Release()

	for {
		conn, nextErr := enum.Next()
		switch {
		case nextErr != nil:
			return nextErr
		case conn == nil:
			// Ran off the end of the enumeration without a match.
			return errAdapterNotFound
		}

		if id, idErr := conn.GetAdapterId(); idErr == nil && id == adapterGUID {
			network, netErr := conn.GetNetwork()
			// The connection is no longer needed once we hold (or failed to
			// get) its network.
			conn.Release()
			if netErr != nil {
				return netErr
			}

			setErr := network.SetCategory(cat)
			network.Release()
			return setErr
		}

		// Unreadable or non-matching adapter: drop it and keep looking.
		conn.Release()
	}
}
// applyNetworkCategory keeps retrying setNetworkCategory until the wintun
// adapter becomes visible in the NLM enumeration, then logs the result. It is
// meant to be launched in its own goroutine.
//
// Only errAdapterNotFound triggers a retry (30 attempts, 500ms apart); any
// other error is logged as a warning and ends the loop immediately. If every
// attempt reports the adapter as missing, a final warning records how long
// was spent waiting.
func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) {
	// COM init and uninit have to happen on the same OS thread, so pin this
	// goroutine to its thread for the duration.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	const (
		attempts = 30
		interval = 500 * time.Millisecond
	)

	for remaining := attempts; remaining > 0; remaining-- {
		err := setNetworkCategory(adapterGUID, cat)
		switch {
		case err == nil:
			l.Info("Set Windows network category", "category", cat.String())
			return
		case !errors.Is(err, errAdapterNotFound):
			l.Warn("Failed to set Windows network category", "error", err, "category", cat.String())
			return
		}
		// Adapter not registered with NLM yet; give it a moment.
		time.Sleep(interval)
	}

	l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set",
		"category", cat.String(),
		"waited", time.Duration(attempts)*interval,
	)
}
+109
View File
@@ -0,0 +1,109 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
	"runtime"
	"testing"
)
// Test_parseNetworkCategory checks the accepted spellings of
// tun.network_category (including case and whitespace variations), the
// "leave it alone" inputs, and rejection of unknown values.
func Test_parseNetworkCategory(t *testing.T) {
	type testCase struct {
		in        string
		wantCat   networkCategory
		wantApply bool
		wantErr   bool
	}

	table := []testCase{
		// Inputs that mean "do not touch the category".
		{in: "", wantCat: 0, wantApply: false, wantErr: false},
		{in: "unset", wantCat: 0, wantApply: false, wantErr: false},
		{in: "  UNSET  ", wantCat: 0, wantApply: false, wantErr: false},
		// Recognized categories.
		{in: "private", wantCat: networkCategoryPrivate, wantApply: true, wantErr: false},
		{in: "Private", wantCat: networkCategoryPrivate, wantApply: true, wantErr: false},
		{in: "  PRIVATE  ", wantCat: networkCategoryPrivate, wantApply: true, wantErr: false},
		{in: "public", wantCat: networkCategoryPublic, wantApply: true, wantErr: false},
		{in: "PUBLIC", wantCat: networkCategoryPublic, wantApply: true, wantErr: false},
		{in: "domain", wantCat: networkCategoryDomainAuthenticated, wantApply: true, wantErr: false},
		{in: "DomainAuthenticated", wantCat: networkCategoryDomainAuthenticated, wantApply: true, wantErr: false},
		// Rejected values.
		{in: "garbage", wantCat: 0, wantApply: false, wantErr: true},
		{in: "privates", wantCat: 0, wantApply: false, wantErr: true},
	}

	for i := range table {
		want := table[i]
		gotCat, gotApply, err := parseNetworkCategory(want.in)
		if gotErr := err != nil; gotErr != want.wantErr {
			t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", want.in, err, want.wantErr)
			continue
		}
		if gotCat == want.wantCat && gotApply == want.wantApply {
			continue
		}
		t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", want.in, gotCat, gotApply, want.wantCat, want.wantApply)
	}
}
// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory
// without mutating the host's network state. It validates the CLSID/IID
// constants and every vtable index by enumerating connections, fetching the
// adapter id and parent network, reading the current category, and writing it
// back unchanged.
//
// Requires Windows but does not require admin or the wintun driver. Skips if
// no network connections are available (unlikely outside of an isolated
// container).
func Test_NLM_round_trip(t *testing.T) {
	// COM initialization is tracked per OS thread and coInit's cleanup has to
	// run on the thread that did the init. Pin the test goroutine so the
	// scheduler cannot move it between coInit and deinit.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	deinit, err := coInit()
	if err != nil {
		t.Fatalf("coInit: %v", err)
	}
	defer deinit()

	nlm, err := createNetworkListManager()
	if err != nil {
		t.Fatalf("createNetworkListManager: %v", err)
	}
	defer nlm.Release()

	enum, err := nlm.GetNetworkConnections()
	if err != nil {
		t.Fatalf("GetNetworkConnections: %v", err)
	}
	defer enum.Release()

	saw := 0
	for {
		conn, err := enum.Next()
		if err != nil {
			t.Fatalf("EnumNetworkConnections.Next: %v", err)
		}
		if conn == nil {
			// End of the enumeration.
			break
		}
		saw++

		if _, err := conn.GetAdapterId(); err != nil {
			conn.Release()
			t.Fatalf("INetworkConnection.GetAdapterId: %v", err)
		}

		net, err := conn.GetNetwork()
		// The connection is released before the error check so it is not
		// leaked on the Fatalf path.
		conn.Release()
		if err != nil {
			t.Fatalf("INetworkConnection.GetNetwork: %v", err)
		}

		cat, err := net.GetCategory()
		if err != nil {
			net.Release()
			t.Fatalf("INetwork.GetCategory: %v", err)
		}

		// Set to the current value so the host's NLM state is unchanged but
		// SetCategory's vtable slot is still validated end-to-end.
		if err := net.SetCategory(cat); err != nil {
			net.Release()
			t.Fatalf("INetwork.SetCategory(%v): %v", cat, err)
		}
		net.Release()
	}

	if saw == 0 {
		t.Skip("no NLM network connections available; skipping round-trip")
	}
}
+23
View File
@@ -0,0 +1,23 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
package overlay
import (
"log/slog"
"github.com/slackhq/nebula/wfp"
)
// installInterfaceBypass asks the wfp package for a PERMIT filter scoped to the wintun interface LUID, so that
// inbound traffic on the nebula adapter bypasses Windows Defender Firewall.
//
// On success it logs at info level and returns the handle from wfp.PermitInterface. On failure it logs a warning
// and returns nil; the failure is not propagated to the caller.
func installInterfaceBypass(l *slog.Logger, luid uint64) closer {
	session, err := wfp.PermitInterface(luid)
	if err == nil {
		l.Info("Installed WFP filters bypassing Windows Defender Firewall on nebula interface")
		return session
	}

	l.Warn("Failed to install WFP bypass filters on nebula interface", "error", err)
	return nil
}
+11
View File
@@ -0,0 +1,11 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import "log/slog"
// installInterfaceBypass is a no-op on windows-386 because we don't currently build for it.
// Both arguments are ignored and the returned closer is always nil, so no
// bypass is installed and there is nothing to close.
func installInterfaceBypass(_ *slog.Logger, _ uint64) closer {
return nil
}
+49 -5
View File
@@ -15,6 +15,7 @@ import (
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/udp"
)
type TestTun struct {
@@ -54,9 +55,12 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTu
return nil, fmt.Errorf("newTunFromFd not supported")
}
// Send will place a byte array onto the receive queue for nebula to consume
// Send will place a byte array onto the receive queue for nebula to consume.
// These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get
// packets should exit the udp side, capture them with udpConn.Get.
//
// Send copies the input via the freelist, so the caller is free to mutate
// or reuse it after the call returns.
func (t *TestTun) Send(packet []byte) {
if t.closed.Load() {
return
@@ -65,7 +69,9 @@ func (t *TestTun) Send(packet []byte) {
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
}
t.rxPackets <- packet
buf := acquireTunBuf(len(packet))
copy(buf, packet)
t.rxPackets <- buf
}
// Get will pull an unencrypted ip layer frame from the transmit queue
@@ -110,12 +116,44 @@ func (t *TestTun) Write(b []byte) (n int, err error) {
return 0, io.ErrClosedPipe
}
packet := make([]byte, len(b), len(b))
packet := acquireTunBuf(len(b))
copy(packet, b)
t.TxPackets <- packet
return len(b), nil
}
// ReleaseTunBuf hands a slice obtained from TxPackets back to the harness freelist. The caller must not touch the
// bytes once this returns. A nil slice is ignored, and when the freelist is already full the buffer is simply
// dropped for the GC to collect.
//
// The freelist is a channel rather than a sync.Pool because putting a []byte in a sync.Pool escapes the slice
// header to heap.
func ReleaseTunBuf(b []byte) {
	if b != nil {
		select {
		case tunBufFreelist <- b:
			// Recycled.
		default:
			// No room; let the GC have it.
		}
	}
}
// tunBufFreelist retains the backing arrays for TestTun.Write so steady-state allocation drops to zero once the
// freelist has saturated for the current MTU.
var tunBufFreelist = make(chan []byte, 64)
// acquireTunBuf returns a slice of length n, taken from the freelist when one is available and large enough.
//
// If the freelist is empty a new buffer with capacity udp.MTU is allocated. If the candidate buffer (recycled or
// new) has less than n capacity it is discarded in favor of a fresh allocation of exactly n bytes. Recycled
// buffers are not cleared, so callers are expected to overwrite the returned bytes.
func acquireTunBuf(n int) []byte {
	var buf []byte
	select {
	case buf = <-tunBufFreelist:
	default:
		buf = make([]byte, 0, udp.MTU)
	}

	if n > cap(buf) {
		return make([]byte, n)
	}
	return buf[:n]
}
func (t *TestTun) Close() error {
if t.closed.CompareAndSwap(false, true) {
close(t.rxPackets)
@@ -129,8 +167,14 @@ func (t *TestTun) Read(b []byte) (int, error) {
if !ok {
return 0, os.ErrClosed
}
n := len(p)
copy(b, p)
return len(p), nil
// Send always pushes a freelist-acquired slice, return it once we've copied the bytes into the caller's buffer.
select {
case tunBufFreelist <- p:
default:
}
return n, nil
}
func (t *TestTun) SupportsMultiqueue() bool {
+44 -6
View File
@@ -25,6 +25,10 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
// closer is the handle type returned by installInterfaceBypass and stored on
// winTun; its Close method takes no arguments and returns no error.
type closer interface {
Close()
}
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
@@ -33,6 +37,11 @@ type winTun struct {
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
guid windows.GUID
networkCategory networkCategory
setCategory bool
bypassWDF bool
wdfBypass closer
l *slog.Logger
tun *wintun.NativeTun
@@ -54,10 +63,19 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private"))
if err != nil {
return nil, err
}
t := &winTun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
guid: *guid,
networkCategory: cat,
setCategory: setCat,
bypassWDF: c.GetBool("tun.windows_bypass_wdf", true),
l: l,
}
@@ -142,6 +160,17 @@ func (t *winTun) Activate() error {
return err
}
if t.setCategory {
// The wintun adapter takes a moment to register with the Network List
// Manager, so we apply the category in the background and retry until
// it shows up.
go applyNetworkCategory(t.l, t.guid, t.networkCategory)
}
if t.bypassWDF {
t.wdfBypass = installInterfaceBypass(t.l, uint64(t.tun.LUID()))
}
return nil
}
@@ -156,11 +185,8 @@ func (t *winTun) addRoutes(logErrors bool) error {
continue
}
// Add our unsafe route
// Windows does not support multipath routes natively, so we install only a single route.
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
// Add our unsafe route as an on-link route to the nebula tun device.
err := luid.AddRoute(r.Cidr, unspecifiedNextHop(r.Cidr), uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
@@ -206,7 +232,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
}
// See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
err := luid.DeleteRoute(r.Cidr, unspecifiedNextHop(r.Cidr))
if err != nil {
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
@@ -258,9 +284,21 @@ func (t *winTun) Close() error {
_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)
if t.wdfBypass != nil {
t.wdfBypass.Close()
t.wdfBypass = nil
}
return t.tun.Close()
}
// unspecifiedNextHop returns the all-zero address to use as the next hop for a route to p: 0.0.0.0 when p's
// address is a plain IPv4 address, and :: for anything else.
func unspecifiedNextHop(p netip.Prefix) netip.Addr {
	if !p.Addr().Is4() {
		return netip.IPv6Unspecified()
	}
	return netip.IPv4Unspecified()
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
+3 -5
View File
@@ -99,12 +99,10 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
var currentState *CertState
if initial {
cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
case "aes", "chachapoly":
// Each post-handshake CipherState in noiseutil hardcodes its own
// nonce endianness now, so there's nothing to set up here.
default:
return util.NewContextualError(
"unknown cipher",
+158 -33
View File
@@ -1,24 +1,70 @@
package nebula
import (
"context"
"log/slog"
"net/netip"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
)
// holepunchQueueSize buffers the channel that pending holepunchJobs land on after their delay timer fires.
const holepunchQueueSize = 64
// holepunchJob is one scheduled item delivered to the worker goroutine.
// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context.
// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback").
type holepunchJob struct {
target netip.AddrPort
vpnAddr netip.Addr
}
// lighthouseChecker is the slice of LightHouse that Punchy actually needs.
// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse
// already holds a *Punchy, and the bidirectional pointer reference is awkward
// even within the same package). Tests can also substitute a fake.
type lighthouseChecker interface {
IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool
}
// Punchy holds the punchy.* settings together with the state needed to send
// punch packets. The settings are stored in atomics and updated by reload;
// the runtime dependencies at the bottom are filled in by Start.
type Punchy struct {
// Settings, written by reload.
punch atomic.Bool
respond atomic.Bool
delay atomic.Int64 // a time.Duration stored as int64 (punchy.delay)
respondDelay atomic.Int64 // a time.Duration stored as int64 (punchy.respond_delay)
punchEverything atomic.Bool // punchy.target_all_remotes
// Set by NewPunchyFromConfig.
sched *Scheduler[holepunchJob]
punchConn udp.Conn
metricHolepunchTx metrics.Counter // incremented for each scheduled punch the worker sends
metricPunchyTx metrics.Counter // incremented for each immediate punch
// Set by Start; ctx is nil until then, which Schedule and ScheduleRespond
// use to detect that Start has not been called.
ctx context.Context
ifce EncWriter
hm *HostMap
lh lighthouseChecker
l *slog.Logger
}
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
p := &Punchy{l: l}
func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy {
p := &Punchy{
l: l,
punchConn: punchConn,
sched: NewScheduler[holepunchJob](holepunchQueueSize),
metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
}
if c.GetBool("stats.lighthouse_metrics", false) {
p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
p.metricHolepunchTx = metrics.NilCounter{}
}
p.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
@@ -29,7 +75,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
}
func (p *Punchy) reload(c *config.C, initial bool) {
if initial {
if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
var yes bool
if c.IsSet("punchy.punch") {
yes = c.GetBool("punchy.punch", false)
@@ -38,16 +84,15 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punchy", false)
}
p.punch.Store(yes)
if yes {
old := p.punch.Swap(yes)
switch {
case initial && yes:
p.l.Info("punchy enabled")
} else {
case initial:
p.l.Info("punchy disabled")
case old != yes:
p.l.Info("punchy.punch changed", "punch", yes)
}
} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
}
if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") {
@@ -59,52 +104,132 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punch_back", false)
}
p.respond.Store(yes)
if !initial {
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
old := p.respond.Swap(yes)
if !initial && old != yes {
p.l.Info("punchy.respond changed", "respond", yes)
}
}
//NOTE: this will not apply to any in progress operations, only the next one
if initial || c.HasChanged("punchy.delay") {
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
if !initial {
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
newDelay := int64(c.GetDuration("punchy.delay", time.Second))
old := p.delay.Swap(newDelay)
if !initial && old != newDelay {
p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay))
}
}
if initial || c.HasChanged("punchy.target_all_remotes") {
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
if !initial {
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
yes := c.GetBool("punchy.target_all_remotes", false)
old := p.punchEverything.Swap(yes)
if !initial && old != yes {
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes)
}
}
if initial || c.HasChanged("punchy.respond_delay") {
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
if !initial {
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second))
old := p.respondDelay.Swap(newDelay)
if !initial && old != newDelay {
p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay))
}
}
}
func (p *Punchy) GetPunch() bool {
return p.punch.Load()
// Schedule queues a punch packet to target, to be sent after the configured delay.
// vpnAddr is the peer's vpn addr, used for log context when the packet actually fires.
// No-op if target is not a valid AddrPort or if Start has not yet been called. Safe to call from any goroutine.
func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) {
if !target.IsValid() || p.ctx == nil {
return
}
p.scheduleJob(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load()))
}
func (p *Punchy) GetRespond() bool {
return p.respond.Load()
// ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay,
// gated on punchy.respond. No-op when respond is disabled or before Start has been called.
func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) {
if !p.respond.Load() || p.ctx == nil {
return
}
p.scheduleJob(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load()))
}
func (p *Punchy) GetDelay() time.Duration {
return (time.Duration)(p.delay.Load())
// scheduleJob delegates to the pooled Scheduler.
// The callback observes p.ctx so a job that becomes due after Stop is dropped instead of queued.
func (p *Punchy) scheduleJob(job holepunchJob, delay time.Duration) {
p.sched.Schedule(p.ctx, job, delay)
}
func (p *Punchy) GetRespondDelay() time.Duration {
return (time.Duration)(p.respondDelay.Load())
// SendPunch sends an immediate keepalive punch for an idle hostinfo.
// The configured punchy.target_all_remotes mode picks the targets. Gated on punchy.punch and the lighthouse-skip rule
// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm).
func (p *Punchy) SendPunch(hostinfo *HostInfo) {
if !p.punch.Load() {
return
}
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
return
}
func (p *Punchy) GetTargetEverything() bool {
return p.punchEverything.Load()
if p.punchEverything.Load() {
p.sendPunchToAllRemotes(hostinfo)
} else if hostinfo.remote.IsValid() {
p.metricPunchyTx.Inc(1)
p.punchConn.WriteTo([]byte{1}, hostinfo.remote)
}
}
// SendPunchToAll fans a punch out to every known remote of hostinfo. It does nothing unless both
// punchy.target_all_remotes and punchy.punch are enabled, and it never punches a host that is a lighthouse.
//
// The connection manager uses this while traffic is outbound-only: that traffic keeps the primary remote's NAT
// state warm on its own, but the non-primary remotes need a separate refresh. Punching the primary again as part
// of the fan-out is redundant but harmless.
func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) {
	enabled := p.punchEverything.Load() && p.punch.Load()
	if enabled && !p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
		p.sendPunchToAllRemotes(hostinfo)
	}
}
// sendPunchToAllRemotes writes a one-byte punch to each remote that hostinfo.remotes.ForEach yields for the
// host map's preferred ranges, bumping the punchy tx counter once per remote. Write errors are ignored.
func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) {
	punch := func(remote netip.AddrPort, _ bool) {
		p.metricPunchyTx.Inc(1)
		p.punchConn.WriteTo([]byte{1}, remote)
	}
	hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), punch)
}
// Start stores the runtime dependencies on p and launches the scheduler worker goroutine.
//
// The worker handles each due holepunchJob in one of two ways:
//   - a job with a valid target gets a one-byte UDP punch written to that target;
//   - otherwise, a job with a valid vpnAddr gets an encrypted nebula test request sent to that vpn address.
//
// A job with neither is ignored.
func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) {
	p.ctx, p.ifce, p.hm, p.lh = ctx, ifce, hm, lh

	// Scratch buffers shared by every job the worker processes.
	var (
		nb    = make([]byte, 12)
		out   = make([]byte, mtu)
		empty = []byte{0}
	)

	handle := func(job holepunchJob) {
		if job.target.IsValid() {
			if p.l.Enabled(context.Background(), slog.LevelDebug) {
				p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr)
			}
			p.metricHolepunchTx.Inc(1)
			p.punchConn.WriteTo(empty, job.target)
			return
		}

		if !job.vpnAddr.IsValid() {
			return
		}

		// A nebula test packet to the host trying to contact us.
		// In the case of a double nat or other difficult scenario, this may help establish a tunnel.
		if p.l.Enabled(context.Background(), slog.LevelDebug) {
			p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr)
		}
		p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out)
	}

	go p.sched.Run(ctx, handle)
}
+40 -41
View File
@@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) {
c := config.NewC(l)
// Test defaults
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.False(t, p.GetPunch())
assert.False(t, p.GetRespond())
assert.Equal(t, time.Second, p.GetDelay())
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.False(t, p.punch.Load())
assert.False(t, p.respond.Load())
assert.Equal(t, time.Second, time.Duration(p.delay.Load()))
assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load()))
// punchy deprecation
c.Settings["punchy"] = true
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.punch.Load())
// punchy.punch
c.Settings["punchy"] = map[string]any{"punch": true}
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.punch.Load())
// punch_back deprecation
c.Settings["punch_back"] = true
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.respond.Load())
// punchy.respond
c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punch_back"] = false
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.respond.Load())
// punchy.delay
c.Settings["punchy"] = map[string]any{"delay": "1m"}
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetDelay())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.Equal(t, time.Minute, time.Duration(p.delay.Load()))
// punchy.respond_delay
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetRespondDelay())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load()))
}
func TestPunchy_reload(t *testing.T) {
@@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) {
delay, _ := time.ParseDuration("1m")
require.NoError(t, c.LoadString(`
punchy:
punch: false
delay: 1m
respond: false
`))
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, delay, p.GetDelay())
assert.False(t, p.GetRespond())
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.False(t, p.punch.Load())
assert.Equal(t, delay, time.Duration(p.delay.Load()))
assert.False(t, p.respond.Load())
newDelay, _ := time.ParseDuration("10m")
require.NoError(t, c.ReloadConfigString(`
punchy:
punch: true
delay: 10m
respond: true
`))
p.reload(c, false)
assert.Equal(t, newDelay, p.GetDelay())
assert.True(t, p.GetRespond())
assert.True(t, p.punch.Load())
assert.Equal(t, newDelay, time.Duration(p.delay.Load()))
assert.True(t, p.respond.Load())
}
// The tests below pin the shape of each log line Punchy produces so changes
// cannot silently break whatever operators are grepping for. The assertions
// are on the structured message + attrs (e.g. "punchy.respond changed" with
// a respond=true field) rather than a formatted string.
//
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
// not supported" warning whenever any key under punchy changes, because of
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
// punchy form. The tests filter by message rather than asserting total
// entry counts so that warning is tolerated without being locked into
// the format.
// a respond=true field) rather than a formatted string. Tests filter by
// message rather than asserting total entry counts so unrelated info lines
// are tolerated without being locked into the format.
type capturedEntry struct {
Level slog.Level
@@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
entry := findEntry(t, hook.entries, "punchy enabled")
assert.Equal(t, slog.LevelInfo, entry.Level)
@@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
entry := findEntry(t, hook.entries, "punchy disabled")
assert.Equal(t, slog.LevelInfo, entry.Level)
assert.Empty(t, entry.Attrs)
}
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
func TestPunchy_LogFormat_ReloadPunch(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
assert.Equal(t, slog.LevelWarn, entry.Level)
assert.Empty(t, entry.Attrs)
entry := findEntry(t, hook.entries, "punchy.punch changed")
assert.Equal(t, slog.LevelInfo, entry.Level)
assert.Equal(t, map[string]any{"punch": true}, entry.Attrs)
}
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
@@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
@@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
@@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
+175 -4
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"net/netip"
"slices"
"sync/atomic"
"github.com/slackhq/nebula/cert"
@@ -18,6 +19,7 @@ type relayManager struct {
l *slog.Logger
hostmap *HostMap
amRelay atomic.Bool
useRelays atomic.Bool
}
func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
@@ -36,8 +38,10 @@ func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *c
}
func (rm *relayManager) reload(c *config.C, initial bool) error {
if initial || c.HasChanged("relay.am_relay") {
rm.setAmRelay(c.GetBool("relay.am_relay", false))
if initial || c.HasChanged("relay.am_relay") || c.HasChanged("relay.use_relays") {
amRelay := c.GetBool("relay.am_relay", false)
rm.amRelay.Store(amRelay)
rm.useRelays.Store(c.GetBool("relay.use_relays", true) && !amRelay)
}
return nil
}
@@ -46,8 +50,175 @@ func (rm *relayManager) GetAmRelay() bool {
return rm.amRelay.Load()
}
func (rm *relayManager) setAmRelay(v bool) {
rm.amRelay.Store(v)
// GetUseRelays reports the current value of the useRelays flag.
// reload derives the flag from relay.use_relays (default true) and forces it to false
// whenever relay.am_relay is enabled.
func (rm *relayManager) GetUseRelays() bool {
return rm.useRelays.Load()
}
// StartRelays drives the relay-establishment side of an outbound handshake attempt.
// For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest,
// retransmits one that may have been lost, or, once the relay is Established, forwards the in-progress
// stage 0 handshake packet for vpnIp through it.
//
// f is the interface used to send packets and start handshakes, vpnIp is the peer we are trying to
// reach, hh carries the in-progress handshake state (its lastRelays field is updated here), and
// stage0 is the handshake packet to forward through any Established relay.
//
// When relays are disabled or the host has no known relays, hh.lastRelays is cleared and nothing is sent.
func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hh *HandshakeHostInfo, stage0 []byte) {
	hostinfo := hh.hostinfo
	if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 {
		hh.lastRelays = nil
		return
	}

	relays := hostinfo.remotes.relays

	// The relay list is announced at Info only when it differs from the previous attempt;
	// identical retries drop to Debug. prior keeps the previous set for the per-relay level below.
	listLevel := slog.LevelDebug
	prior := hh.lastRelays
	if !slices.Equal(relays, prior) {
		listLevel = slog.LevelInfo
		hh.lastRelays = slices.Clone(relays)
	}

	hl := hostinfo.logger(rm.l)
	hl.Log(context.Background(), listLevel, "Attempt to relay through hosts", "relays", relays)

	// Send a RelayRequest to all known Relay IP's
	for _, relay := range relays {
		// Don't relay through the host I'm trying to connect to
		if relay == vpnIp {
			continue
		}

		// Don't relay to myself
		if f.myVpnAddrsTable.Contains(relay) {
			continue
		}

		// Each relay's per-attempt log fires at Info on the first time we hit it and Debug after that.
		level := slog.LevelInfo
		if slices.Contains(prior, relay) {
			level = slog.LevelDebug
		}

		relayHostInfo := rm.hostmap.QueryVpnAddr(relay)
		if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
			hl.Log(context.Background(), level, "Establish tunnel to relay target", "relay", relay.String())
			f.Handshake(relay)
			continue
		}

		// Check the relay HostInfo to see if we already established a relay through
		existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
		if !ok {
			// No relays exist or requested yet.
			idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested)
			if err != nil {
				// Without a registered relay index there is nothing a response could be matched
				// against, so do not send a request for this relay on this attempt.
				hl.Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
				continue
			}
			rm.sendCreateRelayRequest(f, hl, relayHostInfo, relay, vpnIp, idx, level)
			continue
		}

		switch existingRelay.State {
		case Established:
			hl.Log(context.Background(), level, "Send handshake via relay", "relay", relay.String())
			f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false)
		case Disestablished:
			// Mark this relay as 'requested'
			relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
			fallthrough
		case Requested:
			hl.Log(context.Background(), level, "Re-send CreateRelay request", "relay", relay.String())
			// Re-send the CreateRelay request, in case the previous one was lost.
			rm.sendCreateRelayRequest(f, hl, relayHostInfo, relay, vpnIp, existingRelay.LocalIndex, level)
		case PeerRequested:
			// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
			fallthrough
		default:
			hl.Error("Relay unexpected state",
				"vpnIp", vpnIp,
				"state", existingRelay.State,
				"relay", relay,
			)
		}
	}
}

// sendCreateRelayRequest builds a CreateRelayRequest control message asking relay to forward
// between our first vpn address and vpnIp, and sends it over relayHostInfo.
// idx is the initiator relay index to advertise and level is the level for the success log line.
// The address encoding follows the relay's certificate version; a v1 relay can only carry IPv4
// addresses, so the request is abandoned (with an error log) if either side is not IPv4.
func (rm *relayManager) sendCreateRelayRequest(f *Interface, hl *slog.Logger, relayHostInfo *HostInfo, relay, vpnIp netip.Addr, idx uint32, level slog.Level) {
	m := NebulaControl{
		Type:                NebulaControl_CreateRelayRequest,
		InitiatorRelayIndex: idx,
	}

	switch relayHostInfo.GetCert().Certificate.Version() {
	case cert.Version1:
		if !f.myVpnAddrs[0].Is4() {
			hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
			return
		}

		if !vpnIp.Is4() {
			hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
			return
		}

		b := f.myVpnAddrs[0].As4()
		m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
		b = vpnIp.As4()
		m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
	case cert.Version2:
		m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0])
		m.RelayToAddr = netAddrToProtoAddr(vpnIp)
	default:
		hl.Error("Unknown certificate version found while creating relay")
		return
	}

	msg, err := m.Marshal()
	if err != nil {
		hl.Error("Failed to marshal Control message to create relay", "error", err)
		return
	}

	// This must send over the hostinfo, not over hm.Hosts[ip]
	f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
	rm.l.Log(context.Background(), level, "send CreateRelayRequest",
		"relayFrom", f.myVpnAddrs[0],
		"relayTo", vpnIp,
		"initiatorRelayIndex", idx,
		"relay", relay,
	)
}
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
+97
View File
@@ -0,0 +1,97 @@
package nebula
import (
"bytes"
"log/slog"
"net/netip"
"testing"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
// TestStartRelaysLogDedupe verifies that repeated attempts with the same relay set drop the log
// chatter to Debug, mirroring how the normal handshake retry loop quiets down once it's already
// announced its targets.
func TestStartRelaysLogDedupe(t *testing.T) {
vpnIp := netip.MustParseAddr("100.64.99.4")
otherRelay := netip.MustParseAddr("100.64.99.5")
newHH := func() *HandshakeHostInfo {
// Use the target's own vpnIp as the "relay" so the loop body skips it without
// touching any sender-side state. That isolates the test to the level-selection
// behavior of the top-level "Attempt to relay through hosts" log.
hostinfo := &HostInfo{
vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1,
remotes: NewRemoteList([]netip.Addr{vpnIp}, nil),
}
hostinfo.remotes.relays = []netip.Addr{vpnIp}
return &HandshakeHostInfo{hostinfo: hostinfo}
}
// Park any extra relay addresses we'll introduce mid-test in myVpnAddrsTable so the loop
// body always skips before touching f.Handshake (which would need a real handshakeManager).
addrTable := new(bart.Lite)
addrTable.Insert(netip.PrefixFrom(otherRelay, otherRelay.BitLen()))
f := &Interface{myVpnAddrsTable: addrTable}
newRM := func(buf *bytes.Buffer) *relayManager {
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
rm := &relayManager{l: l, hostmap: newHostMap(l)}
rm.useRelays.Store(true)
return rm
}
const msg = `msg="Attempt to relay through hosts"`
t.Run("first attempt logs at Info", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
hh := newHH()
rm.StartRelays(f, vpnIp, hh, nil)
assert.Equal(t, []netip.Addr{vpnIp}, hh.lastRelays, "lastRelays should record the relay set we just attempted")
assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info level on first attempt")
})
t.Run("repeat attempt with same relays drops to Debug", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
hh := newHH()
rm.StartRelays(f, vpnIp, hh, nil)
first := append([]netip.Addr(nil), hh.lastRelays...)
buf.Reset()
rm.StartRelays(f, vpnIp, hh, nil)
assert.Equal(t, first, hh.lastRelays)
assert.Contains(t, buf.String(), "level=DEBUG "+msg, "expected Debug level on identical retry")
assert.NotContains(t, buf.String(), "level=INFO "+msg, "Info should not fire on identical retry")
})
t.Run("changed relay list bumps back to Info", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
hh := newHH()
rm.StartRelays(f, vpnIp, hh, nil)
buf.Reset()
// The lighthouse handed us a new set this round.
hh.hostinfo.remotes.relays = []netip.Addr{vpnIp, otherRelay}
rm.StartRelays(f, vpnIp, hh, nil)
assert.Equal(t, []netip.Addr{vpnIp, otherRelay}, hh.lastRelays)
assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info when the relay list changes")
})
t.Run("disabled relays clears lastRelays and emits no Attempt log", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
rm.useRelays.Store(false)
hh := newHH()
hh.lastRelays = []netip.Addr{vpnIp}
rm.StartRelays(f, vpnIp, hh, nil)
assert.Nil(t, hh.lastRelays, "with relays disabled lastRelays should be cleared")
assert.NotContains(t, buf.String(), msg, "should not log when we shortcut out")
})
}
+25
View File
@@ -239,6 +239,31 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
r.hr = hr
}
// ResetForOwner truncates the reported v4 and v6 address slices cached for ownerVpnAddr to zero
// length (other owners' entries are untouched), cancels any pending hostname resolution, and
// flags the address list for a rebuild. An owner with no cache entry is not an error.
// The write lock is held for the duration of the call.
func (r *RemoteList) ResetForOwner(ownerVpnAddr netip.Addr) {
	r.Lock()
	defer r.Unlock()

	r.hr.Cancel()
	r.shouldRebuild = true

	entry, found := r.cache[ownerVpnAddr]
	if !found {
		return
	}

	// Truncate in place so the backing arrays stay available for the owner's next report.
	if v4 := entry.v4; v4 != nil {
		v4.reported = v4.reported[:0]
	}
	if v6 := entry.v6; v6 != nil {
		v6.reported = v6.reported[:0]
	}
}
// ClearHostnameResults cancels the in-flight DNS resolver goroutine (if any) and drops the resolved IP cache.
// It does so by replacing the stored hostname results with nil through unlockedSetHostnamesResults,
// then marks the address list for rebuild. The write lock is held for the duration of the call.
func (r *RemoteList) ClearHostnameResults() {
r.Lock()
defer r.Unlock()
r.unlockedSetHostnamesResults(nil)
r.shouldRebuild = true
}
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
+89
View File
@@ -6,8 +6,22 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// trackedHostnameResults returns a *hostnamesResults wired to the supplied cancel function and
// pre-loaded with the given "ip:port" strings as already-resolved addresses. Tests use it to
// observe cancellation and to confirm resolved IPs remain readable afterwards, without starting
// a real DNS resolver. It panics if any address fails to parse.
func trackedHostnameResults(cancelFn func(), addrs ...string) *hostnamesResults {
	resolved := map[netip.AddrPort]struct{}{}
	for i := range addrs {
		resolved[netip.MustParseAddrPort(addrs[i])] = struct{}{}
	}

	results := &hostnamesResults{cancelFn: cancelFn}
	results.ips.Store(&resolved)
	return results
}
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4(
@@ -112,6 +126,81 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
}
func TestRemoteList_ResetForOwner(t *testing.T) {
ourselves := netip.MustParseAddr("10.0.0.1")
otherOwner := netip.MustParseAddr("10.0.0.2")
vpnAddr := netip.MustParseAddr("10.0.0.99")
rl := NewRemoteList([]netip.Addr{vpnAddr}, nil)
rl.unlockedSetV4(ourselves, vpnAddr,
[]*V4AddrPort{newIp4AndPortFromString("1.1.1.1:4242")},
func(netip.Addr, *V4AddrPort) bool { return true },
)
rl.unlockedSetV6(ourselves, vpnAddr,
[]*V6AddrPort{newIp6AndPortFromString("[1::1]:4242")},
func(netip.Addr, *V6AddrPort) bool { return true },
)
rl.unlockedSetV4(otherOwner, vpnAddr,
[]*V4AddrPort{newIp4AndPortFromString("2.2.2.2:4242")},
func(netip.Addr, *V4AddrPort) bool { return true },
)
canceled := 0
hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")
rl.Lock()
rl.unlockedSetHostnamesResults(hr)
rl.Unlock()
rl.ResetForOwner(ourselves)
rl.RLock()
defer rl.RUnlock()
assert.Empty(t, rl.cache[ourselves].v4.reported, "our v4 reported should be cleared")
assert.Empty(t, rl.cache[ourselves].v6.reported, "our v6 reported should be cleared")
assert.Len(t, rl.cache[otherOwner].v4.reported, 1, "other owner's contribution must be preserved")
assert.Equal(t, "2.2.2.2:4242", protoV4AddrPortToNetAddrPort(rl.cache[otherOwner].v4.reported[0]).String())
assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled")
assert.Same(t, hr, rl.hr, "hostnamesResults must be preserved so DNS-resolved IPs keep feeding addrs until replaced")
assert.NotEmpty(t, rl.hr.GetAddrs(), "previously-resolved IPs should still be readable after cancel")
assert.True(t, rl.shouldRebuild, "shouldRebuild must be set so the next Rebuild recomputes addrs")
}
// TestRemoteList_ResetForOwner_NoEntry resets an owner that has no cache entry: the call must
// not panic, must still cancel the existing hostnamesResults, and must still set shouldRebuild.
func TestRemoteList_ResetForOwner_NoEntry(t *testing.T) {
	list := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil)

	cancelCalls := 0
	tracked := trackedHostnameResults(func() { cancelCalls++ }, "3.3.3.3:4242")
	list.Lock()
	list.unlockedSetHostnamesResults(tracked)
	list.Unlock()

	unknownOwner := netip.MustParseAddr("10.0.0.1")
	list.ResetForOwner(unknownOwner)

	list.RLock()
	defer list.RUnlock()
	assert.Equal(t, 1, cancelCalls)
	assert.True(t, list.shouldRebuild)
}
// TestRemoteList_ClearHostnameResults verifies that clearing cancels the resolver, drops the
// stored hostnamesResults entirely, and marks the list for rebuild.
func TestRemoteList_ClearHostnameResults(t *testing.T) {
	list := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil)

	cancelCalls := 0
	tracked := trackedHostnameResults(func() { cancelCalls++ }, "3.3.3.3:4242")
	list.Lock()
	list.unlockedSetHostnamesResults(tracked)
	list.Unlock()

	// Sanity check the fixture before exercising the call under test.
	require.NotEmpty(t, tracked.GetAddrs(), "hostnamesResults should have its fastrack IPs populated")

	list.ClearHostnameResults()

	list.RLock()
	defer list.RUnlock()
	assert.Equal(t, 1, cancelCalls, "DNS resolution goroutine should be canceled")
	assert.Nil(t, list.hr, "hostnamesResults should be dropped")
	assert.True(t, list.shouldRebuild)
}
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4(
+84
View File
@@ -0,0 +1,84 @@
package nebula
import (
"context"
"sync"
"time"
)
// Scheduler is an allocation-conscious dispatch primitive for delayed work.
// Pending items are handed to a runtime timer, and ready items land on a worker
// channel for centralized dispatch in fire-time order.
//
// Pick a Scheduler when fire timing matters (exact deadlines, no bucketing) or when the scheduling
// rate is uneven enough that idle CPU matters. Each fire is a runtime-spawned goroutine running the callback before
// delivering to the worker, which is fine at sparse rates but adds up at line rate.
//
// Pick a TimerWheel when scheduling is high-rate and uniform: its O(1) insert, internal item cache,
// and bucket-batched dispatch are cheaper at scale.
// The caller drives the tick loop (Advance/Purge) and pays for fires at bucket boundaries rather than exact deadlines.
type Scheduler[T any] struct {
	// queue carries items whose timers have fired to the Run worker.
	queue chan T
	// pool recycles schedItems (each with its timer and fire closure) between Schedule calls.
	pool sync.Pool
}

// schedItem is one pooled unit of scheduled work.
// val and ctx are set by Schedule and cleared by fire before the item returns to the pool;
// s, timer and fire are set once when the item is created and never change afterwards.
type schedItem[T any] struct {
	val   T
	ctx   context.Context
	s     *Scheduler[T]
	timer *time.Timer
	fire  func()
}

// parkedTimerDelay is the delay used when creating a pooled item's timer so that it cannot
// fire before it is stopped; the timer is only ever armed for real through Reset.
const parkedTimerDelay = time.Duration(1<<63 - 1)

// NewScheduler builds a Scheduler whose worker channel is sized to queueSize.
// The buffer absorbs bursts of timers firing close together without
// blocking the runtime's callback goroutines on the worker.
func NewScheduler[T any](queueSize int) *Scheduler[T] {
	s := &Scheduler[T]{
		queue: make(chan T, queueSize),
	}
	s.pool.New = func() any {
		si := &schedItem[T]{s: s}
		// fire is allocated exactly once per pool-resident item.
		// The closure captures only `si`, which stays stable for the item's lifetime.
		si.fire = func() {
			select {
			case si.s.queue <- si.val:
			case <-si.ctx.Done():
			}
			// Drop references before pooling so a parked item does not pin the value or context.
			var zero T
			si.val = zero
			si.ctx = nil
			si.s.pool.Put(si)
		}
		// Create the timer here, parked and stopped, so the timer field is fully written before
		// the item can ever be armed. If it were assigned after arming, a short delay would let
		// fire return the item to the pool while that assignment was still pending, and another
		// goroutine pulling the item from the pool would read the field unsynchronised.
		si.timer = time.AfterFunc(parkedTimerDelay, si.fire)
		si.timer.Stop()
		return si
	}
	return s
}

// Schedule arranges item to be delivered to the worker after delay.
// The runtime's timer heap handles the wait, so the scheduler itself burns no CPU while idle.
// The callback observes ctx: if the worker queue is full when the timer fires and ctx is
// cancelled before space frees up, the item is dropped instead of queued.
// ctx must be non-nil.
func (s *Scheduler[T]) Schedule(ctx context.Context, item T, delay time.Duration) {
	si := s.pool.Get().(*schedItem[T])
	si.val = item
	si.ctx = ctx
	// Every pooled item already owns a stopped (or already fired) timer; arming it is the last
	// step so fire observes the val and ctx written above.
	si.timer.Reset(delay)
}

// Run drains the worker queue, calling fn for each item. Returns when ctx is cancelled.
// Tests that want deterministic timing should drive the queue directly rather than going through Schedule + Run.
func (s *Scheduler[T]) Run(ctx context.Context, fn func(T)) {
	for {
		select {
		case <-ctx.Done():
			return
		case item := <-s.queue:
			fn(item)
		}
	}
}
+79
View File
@@ -0,0 +1,79 @@
package nebula
import (
"context"
"testing"
"time"
)
// TestScheduler_PooledReuse schedules more items than the worker queue can hold at once and
// checks that every one of them reaches the Run callback within two seconds.
func TestScheduler_PooledReuse(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	const N = 100

	sched := NewScheduler[int](16)
	out := make(chan int, 256)
	go sched.Run(ctx, func(item int) { out <- item })

	for i := 0; i < N; i++ {
		sched.Schedule(ctx, i, time.Millisecond)
	}

	timeout := time.After(2 * time.Second)
	for received := 0; received < N; {
		select {
		case <-out:
			received++
		case <-timeout:
			t.Fatalf("only %d/%d items delivered", received, N)
		}
	}
}
// BenchmarkScheduler_Schedule reports allocations per Schedule call.
// In steady state the Scheduler's sync.Pool means we should see zero allocs per op once the pool warms up.
func BenchmarkScheduler_Schedule(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// The worker queue is sized to b.N, so the buffer alone can hold every item this run schedules.
s := NewScheduler[int](b.N)
// Drain in the background with a no-op callback; only the Schedule side is being measured.
go s.Run(ctx, func(int) {})
b.ReportAllocs()
// Exclude the setup above from the timed region.
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Schedule(ctx, i, time.Microsecond)
}
}
// BenchmarkBareAfterFunc is the comparison baseline.
// What we'd pay per Schedule if Punchy called time.AfterFunc directly without the pooled Scheduler.
// Allocates a *time.Timer plus a closure each call.
func BenchmarkBareAfterFunc(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Buffered to b.N so the buffer alone can hold every value this run sends.
queue := make(chan int, b.N)
// Background consumer that discards queued values until ctx is cancelled.
go func() {
for {
select {
case <-ctx.Done():
return
case <-queue:
}
}
}()
b.ReportAllocs()
// Exclude the setup above from the timed region.
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Per-iteration copy so each closure captures its own value
// (needed under the loop-variable semantics of Go versions before 1.22).
i := i
time.AfterFunc(time.Microsecond, func() {
select {
case queue <- i:
case <-ctx.Done():
}
})
}
}
+37 -20
View File
@@ -27,21 +27,20 @@ type SSHServer struct {
commands *radix.Tree
listener net.Listener
// Call the cancel() function to stop all active sessions
// ctx parents per-Run contexts. Cancelling it (e.g. via Control.Stop) tears the server down even
// across reloads, since each Run derives a fresh child rather than reusing this one directly.
ctx context.Context
cancel func()
}
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
ctx, cancel := context.WithCancel(context.Background())
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen.
// The ssh server's context is parented off the supplied ctx so cancelling it
// (e.g. on Control.Stop) tears down active sessions and closes the listener.
func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) {
s := &SSHServer{
trustedKeys: make(map[string]map[string]bool),
l: l,
commands: radix.New(),
ctx: ctx,
cancel: cancel,
}
cc := ssh.CertChecker{
@@ -151,28 +150,51 @@ func (s *SSHServer) RegisterCommand(c *Command) {
s.commands.Insert(c.Name, c)
}
// Run begins listening and accepting connections
// Run begins listening and accepting connections. Each invocation derives a fresh per-Run context
// from the constructor-supplied ctx so a Stop+Run sequence (used by config reload) starts clean
// rather than carrying a permanently-cancelled context across runs.
func (s *SSHServer) Run(addr string) error {
var err error
s.listener, err = net.Listen("tcp", addr)
if s.ctx.Err() != nil {
return s.ctx.Err()
}
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
// s.listener is the public handle Stop uses to interrupt the active run; listener (the local) is what
// this run owns. They start equal but a fast reload may overwrite s.listener with the next run's
// listener before this run's watcher fires, so each run must close its own listener via the local
// reference.
s.listener = listener
runCtx, cancel := context.WithCancel(s.ctx)
defer cancel()
// Close the listener when this run's context is cancelled. That can come from the parent
// (Control.Stop), from Run returning normally (defer cancel above), or transitively when a sibling
// run cancels through Stop closing the listener. net.Listener.Close is idempotent so a duplicate
// close from Stop is benign.
go func() {
<-runCtx.Done()
if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
s.l.Warn("Failed to close the sshd listener", "error", err)
}
}()
s.l.Info("SSH server is listening", "sshListener", addr)
// Run loops until there is an error
s.run()
s.closeSessions()
s.run(runCtx, listener)
s.l.Info("SSH server stopped listening")
// We don't return an error because run logs for us
return nil
}
func (s *SSHServer) run() {
func (s *SSHServer) run(ctx context.Context, listener net.Listener) {
for {
c, err := s.listener.Accept()
c, err := listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
s.l.Warn("Error in listener, shutting down", "error", err)
@@ -184,7 +206,7 @@ func (s *SSHServer) run() {
// Ensure that a bad client doesn't hurt us by checking for the parent context
// cancellation before calling NewServerConn, and forcing the socket to close when
// the context is cancelled.
sessionContext, sessionCancel := context.WithCancel(s.ctx)
sessionContext, sessionCancel := context.WithCancel(ctx)
go func() {
<-sessionContext.Done()
c.Close()
@@ -227,14 +249,9 @@ func (s *SSHServer) run() {
}
func (s *SSHServer) Stop() {
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil {
if err := s.listener.Close(); err != nil {
s.l.Warn("Failed to close the sshd listener", "error", err)
}
}
}
func (s *SSHServer) closeSessions() {
s.cancel()
}
+17
View File
@@ -8,6 +8,23 @@ import (
// How many timer objects should be cached
const timerCacheMax = 50000
// TimerWheel is a hashed timing wheel: a fixed slot array indexed by (now + delay) % wheelLen,
// with each slot a singly linked list of items due in that bucket.
// Adds are O(1), Purges return items in arrival-within-slot order, and an internal cache of TimeoutItems
// keeps steady-state inserts allocation-free.
//
// The TimerWheel does not handle concurrency or lifecycle on its own.
// Callers drive Advance/Purge from their own ticker loop, take their own locks (or use LockingTimerWheel),
// and decide whether to keep ticking when the wheel is empty.
//
// Pick a TimerWheel when scheduling is high-rate and uniform: line-rate conntrack inserts,
// per-tunnel traffic checks at fixed intervals. O(1) insert plus the item cache means the hot path doesn't allocate.
// Items added in the same tick are dispatched together when that slot rotates current,
// which amortizes the cost of waking the worker.
//
// Pick a Scheduler when delay precision matters or scheduling is sparse or uneven.
// The wheel rounds requested timeouts up to its tick resolution and clamps anything beyond its wheel duration;
// both are silent in this implementation.
type TimerWheel[T any] struct {
// Current tick
current int
+1 -2
View File
@@ -5,12 +5,11 @@ package udp
import (
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"log/slog"
"golang.org/x/sys/unix"
)
+1 -2
View File
@@ -8,12 +8,11 @@ package udp
import (
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"log/slog"
"golang.org/x/sys/unix"
)
+57
View File
@@ -0,0 +1,57 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
package udp
import (
"log/slog"
"sync"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/wfp"
)
// wrapWithWDFBypass wraps a Conn so that the first ReloadConfig consults listen.windows_bypass_wdf
// and installs a WFP PERMIT filter for the listener's bound UDP port. The session is released when Close runs.
// No filter is installed at wrap time; the returned Conn only records conn and the logger l.
func wrapWithWDFBypass(l *slog.Logger, conn Conn) Conn {
return &bypassConn{Conn: conn, l: l}
}
// bypassConn decorates a Conn with an optional WFP bypass session that is installed on the
// first ReloadConfig and released on Close. All other Conn methods come from the embedded Conn.
type bypassConn struct {
Conn
// l receives the install success and failure messages.
l *slog.Logger
// installOnce limits the install attempt made from ReloadConfig to a single try.
installOnce sync.Once
// session is the handle returned by wfp.PermitUDPPort; it stays nil when the bypass is
// disabled or the install failed, and is reset to nil by Close.
session *wfp.Session
}
// ReloadConfig makes a single attempt (on its first invocation only) to install the WFP bypass,
// then forwards the reload to the wrapped Conn. Later calls skip straight to the forwarding,
// whether or not the first attempt installed anything.
func (b *bypassConn) ReloadConfig(c *config.C) {
	b.installOnce.Do(func() { b.installBypass(c) })
	b.Conn.ReloadConfig(c)
}

// installBypass installs WFP filters for the listener's bound UDP port unless
// listen.windows_bypass_wdf (default true) is false. Failures are logged as warnings and
// leave b.session nil; on success the returned session is kept for Close to release.
func (b *bypassConn) installBypass(c *config.C) {
	if enabled := c.GetBool("listen.windows_bypass_wdf", true); !enabled {
		return
	}

	local, err := b.Conn.LocalAddr()
	if err != nil {
		b.l.Warn("Failed to query listener port for WFP bypass", "error", err)
		return
	}

	port := local.Port()
	session, err := wfp.PermitUDPPort(port)
	if err != nil {
		b.l.Warn("Failed to install WFP bypass filters for listener", "error", err)
		return
	}

	b.l.Info("Installed WFP filters bypassing Windows Defender Firewall on UDP listener port",
		"port", port)
	b.session = session
}
// Close releases the WFP session, if one was installed, and then closes the wrapped Conn,
// returning the wrapped Conn's error. The session field is cleared so a repeated Close does
// not release it twice.
func (b *bypassConn) Close() error {
	if held := b.session; held != nil {
		held.Close()
		b.session = nil
	}

	err := b.Conn.Close()
	return err
}
+11
View File
@@ -0,0 +1,11 @@
//go:build !e2e_testing
// +build !e2e_testing
package udp
import "log/slog"
// wrapWithWDFBypass is a no-op on windows-386 since we don't currently build for it.
// It returns conn unchanged and ignores the logger.
func wrapWithWDFBypass(_ *slog.Logger, conn Conn) Conn {
return conn
}
+1 -2
View File
@@ -7,12 +7,11 @@ package udp
import (
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"log/slog"
"golang.org/x/sys/unix"
)
+50 -13
View File
@@ -21,17 +21,48 @@ type Packet struct {
Data []byte
}
// Copy returns a fresh *Packet (from the freelist) with a duplicate Data buffer.
func (u *Packet) Copy() *Packet {
n := &Packet{
To: u.To,
From: u.From,
Data: make([]byte, len(u.Data)),
n := acquirePacket()
n.To = u.To
n.From = u.From
if cap(n.Data) < len(u.Data) {
n.Data = make([]byte, len(u.Data))
} else {
n.Data = n.Data[:len(u.Data)]
}
copy(n.Data, u.Data)
return n
}
// Release hands p back to the harness packet freelist; a nil p is ignored.
// Callers that pull a *Packet from Get / TxPackets must Release when done, and must not touch
// p afterwards. Data is truncated to zero length but keeps its backing array for reuse.
// The freelist is a channel rather than a sync.Pool because sync.Pool's per-P caches drain badly
// under cross-goroutine Get/Put, and putting a []byte in a Pool escapes the slice header to heap.
func (p *Packet) Release() {
	if p != nil {
		p.Data = p.Data[:0]
		select {
		case packetFreelist <- p:
			// Parked for reuse by acquirePacket.
		default:
			// Freelist full; drop the *Packet for the GC.
		}
	}
}
// packetFreelist holds up to 64 released *Packet structs, together with their Data backing
// arrays, so that steady-state allocation drops to zero.
var packetFreelist = make(chan *Packet, 64)

// acquirePacket returns a recycled *Packet when the freelist has one and a freshly allocated
// zero-value Packet otherwise. It never blocks. A recycled packet still carries the From/To of
// its previous use; callers are expected to overwrite every field.
func acquirePacket() *Packet {
	select {
	case recycled := <-packetFreelist:
		return recycled
	default:
	}
	return &Packet{}
}
type TesterConn struct {
Addr netip.AddrPort
@@ -64,13 +95,15 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn,
// this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *TesterConn) Send(packet *Packet) {
h := &header.H{}
if u.l.Enabled(context.Background(), slog.LevelDebug) {
// Parse the header only under debug logging, otherwise the
// allocation would show up in every Send call.
var h header.H
if err := h.Parse(packet.Data); err != nil {
panic(err)
}
if u.l.Enabled(context.Background(), slog.LevelDebug) {
u.l.Debug("UDP receiving injected packet",
"header", h,
"header", &h,
"udpAddr", packet.From,
"dataLen", len(packet.Data),
)
@@ -107,15 +140,18 @@ func (u *TesterConn) Get(block bool) *Packet {
//********************************************************************************************************************//
func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
p := &Packet{
Data: make([]byte, len(b), len(b)),
From: u.Addr,
To: addr,
p := acquirePacket()
if cap(p.Data) < len(b) {
p.Data = make([]byte, len(b))
} else {
p.Data = p.Data[:len(b)]
}
copy(p.Data, b)
p.From = u.Addr
p.To = addr
select {
case <-u.done:
p.Release()
return io.ErrClosedPipe
case u.TxPackets <- p:
return nil
@@ -129,6 +165,7 @@ func (u *TesterConn) ListenOut(r EncReader) error {
return os.ErrClosed
case p := <-u.RxPackets:
r(p.From, p.Data)
p.Release()
}
}
}
+9 -4
View File
@@ -19,13 +19,18 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int)
return nil, fmt.Errorf("multiple udp listeners not supported on windows")
}
var conn Conn
rc, err := NewRIOListener(l, ip, port)
if err == nil {
return rc, nil
}
conn = rc
} else {
l.Error("Falling back to standard udp sockets", "error", err)
return NewGenericListener(l, ip, port, multi, batch)
conn, err = NewGenericListener(l, ip, port, multi, batch)
if err != nil {
return nil, err
}
}
return wrapWithWDFBypass(l, conn), nil
}
func NewListenConfig(multi bool) net.ListenConfig {
+377
View File
@@ -0,0 +1,377 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
// Package wfp installs Windows Filtering Platform (WFP) PERMIT filters in a dynamic, session-scoped sublayer.
// Because WFP sits below Windows Defender Firewall, a high-weight permit at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4/V6 lets
// the matching inbound traffic through regardless of WDF rules.
//
// Each Session owns its own engine handle. When the handle closes, every dynamic object added during the session
// is auto-deleted by Windows, so there are no orphaned filters.
//
// Type definitions and constants are derived from the wireguard-windows firewall package (MIT).
// Only the subset we exercise is reproduced.
package wfp
import (
	"fmt"
	"runtime"
	"unsafe"

	"golang.org/x/sys/windows"
)
// FWPM layer GUIDs (fwpmu.h).
//
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = e1cd9fe7-f4b5-4273-96c0-592e487b8650
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = a3b42c97-9f04-4672-b87e-cee9c483257f
//
// These are the two layers at which PermitInterface and PermitUDPPort install their PERMIT filters, one filter per
// address family.
var (
	// fwpmLayerAleAuthRecvAcceptV4 is the IPv4 layer key.
	fwpmLayerAleAuthRecvAcceptV4 = windows.GUID{
		Data1: 0xe1cd9fe7, Data2: 0xf4b5, Data3: 0x4273,
		Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
	}
	// fwpmLayerAleAuthRecvAcceptV6 is the IPv6 layer key.
	fwpmLayerAleAuthRecvAcceptV6 = windows.GUID{
		Data1: 0xa3b42c97, Data2: 0x9f04, Data3: 0x4672,
		Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
	}
)
// FWPM filter condition field GUIDs (the fieldKey of a fwpmFilterCondition0).
//
// FWPM_CONDITION_IP_LOCAL_INTERFACE = 4cd62a49-59c3-4969-b7f3-bda5d32890a4
// FWPM_CONDITION_IP_PROTOCOL        = 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
// FWPM_CONDITION_IP_LOCAL_PORT      = 0c1ba1af-5765-453f-af22-a8f791ac775b
var (
	// fwpmConditionIPLocalInterface matches on the local interface LUID.
	fwpmConditionIPLocalInterface = windows.GUID{
		Data1: 0x4cd62a49,
		Data2: 0x59c3,
		Data3: 0x4969,
		Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
	}

	// fwpmConditionIPProtocol matches on the IP protocol number.
	fwpmConditionIPProtocol = windows.GUID{
		Data1: 0x3971ef2b,
		Data2: 0x623e,
		Data3: 0x4f9a,
		Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
	}

	// fwpmConditionIPLocalPort matches on the local transport port.
	fwpmConditionIPLocalPort = windows.GUID{
		Data1: 0x0c1ba1af,
		Data2: 0x5765,
		Data3: 0x453f,
		Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
	}
)
// IPPROTO_UDP from in.h. addUDPPortFilter uses it as the value of the IP_PROTOCOL condition.
const ipprotoUDP uint8 = 17

// FWP_ACTION_TYPE values (fwptypes.h). PERMIT is terminating.
const fwpActionPermit uint32 = 0x00001002 // 0x2 | FWP_ACTION_FLAG_TERMINATING(0x1000)

// FWP_DATA_TYPE values we use. They go in fwpValue0.type_ and tell WFP how to interpret fwpValue0.value.
const (
	fwpEmpty uint32 = 0
	fwpUint8 uint32 = 1
	fwpUint16 uint32 = 2
	fwpUint64 uint32 = 4
)

// FWP_MATCH_TYPE values. Every condition built in this package uses an equality match.
const fwpMatchEqual uint32 = 0

// FWPM_SESSION flags. openDynamicEngine sets the dynamic flag so that objects added through the session are
// deleted when its engine handle is closed.
const fwpmSessionFlagDynamic uint32 = 0x1
// FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT prevents lower-priority filters in other sublayers from overriding this
// PERMIT -- notably Windows Defender Firewall's MPSSVC_WF sublayer, which shares our 0xFFFF weight.
// Without it, a default WDF block at the same sublayer weight can still win arbitration.
const fwpmFilterFlagClearActionRight uint32 = 0x8
// RPC authentication service passed to FwpmEngineOpen0.
// RPC_C_AUTHN_WINNT works on workgroup machines with no domain context, whereas
// RPC_C_AUTHN_DEFAULT falls back through a chain that can land on something WFP doesn't accept on a fresh box.
const rpcCAuthnWinNT uint32 = 10
// fwpByteBlob (FWP_BYTE_BLOB). 16 bytes on 64-bit.
// It appears as the providerData member of fwpmFilter0 and fwpmSublayer0; the code in this file always leaves it
// zeroed.
type fwpByteBlob struct {
	size uint32 // byte length of the buffer at data (per FWP_BYTE_BLOB)
	_ uint32 // padding so that data is 8-byte aligned
	data *uint8
}
// fwpValue0 / FWP_CONDITION_VALUE0 layout. 16 bytes on 64-bit.
// The union is pointer-sized; types <= 32 bits (UINT8/16/32, INT8/16/32, float) live inline in the low bytes
// of `value`, while UINT64/INT64/double and aggregate types are stored *by pointer*, even on 64-bit, where the
// union member is declared as UINT64*. So when populating an FWP_UINT64 condition, pass
// uintptr(unsafe.Pointer(&luidVar)) instead of the LUID inline.
//
// Because `value` is a uintptr, the Go runtime does not treat it as a reference: whatever it points at must be
// kept alive, and at a stable address, by other means until the syscall that reads it has returned.
type fwpValue0 struct {
	type_ uint32 // one of the fwpUint8/fwpUint16/fwpUint64/... data type tags
	_ uint32 // padding before union to 8-byte alignment
	value uintptr // inline value or address, depending on type_
}
// fwpmDisplayData0 / FWPM_DISPLAY_DATA0. 16 bytes on 64-bit.
// Both members are pointers to UTF-16 strings; this file builds them with windows.UTF16PtrFromString.
type fwpmDisplayData0 struct {
	name *uint16 // human-readable object name
	description *uint16 // longer description of the object
}
// fwpmAction0 / FWPM_ACTION0. 20 bytes; no leading padding because actionType
// is uint32 and GUID's first field is uint32.
// The filters built in this file set only actionType (to fwpActionPermit) and leave filterType zeroed.
type fwpmAction0 struct {
	actionType uint32
	filterType windows.GUID
}
// fwpmFilterCondition0. 40 bytes on 64-bit.
// One condition of a filter: the field identified by fieldKey is compared against conditionValue using matchType.
// The trailing numbers are the member sizes in bytes.
type fwpmFilterCondition0 struct {
	fieldKey windows.GUID // 16
	matchType uint32 // 4
	_ uint32 // 4 padding
	conditionValue fwpValue0 // 16
}
// fwpmFilter0. 200 bytes on 64-bit.
// Mirrors FWPM_FILTER0 as passed to FwpmFilterAdd0. Field order and the explicit padding members are part of the
// ABI and must not be changed. The trailing numbers are byte offsets on 64-bit, derived from the member sizes.
type fwpmFilter0 struct {
	filterKey windows.GUID // 0
	displayData fwpmDisplayData0 // 16
	flags uint32 // 32
	_ uint32 // 36: padding before *GUID
	providerKey *windows.GUID // 40
	providerData fwpByteBlob // 48
	layerKey windows.GUID // 64
	subLayerKey windows.GUID // 80
	weight fwpValue0 // 96
	numFilterConditions uint32 // 112: element count of the array at filterCondition
	_ uint32 // 116: padding before pointer
	filterCondition *fwpmFilterCondition0 // 120: first element of the condition array
	action fwpmAction0 // 128, ends at 148
	// 148: layout correction. Go aligns windows.GUID to 4 bytes, so without this the next member would sit at 148
	// instead of 152. NOTE(review): presumably the C member is a union that also holds a UINT64 and is therefore
	// 8-byte aligned -- confirm against fwpmtypes.h.
	_ [4]byte
	providerContextKey windows.GUID // 152
	reserved *windows.GUID // 168
	filterID uint64 // 176
	effectiveWeight fwpValue0 // 184
}
// fwpmSublayer0. 72 bytes on 64-bit.
// Mirrors FWPM_SUBLAYER0 as passed to FwpmSubLayerAdd0. registerSublayer fills in subLayerKey, displayData and
// weight and leaves the remaining members zeroed.
type fwpmSublayer0 struct {
	subLayerKey windows.GUID
	displayData fwpmDisplayData0
	flags uint32
	_ uint32 // padding before *GUID
	providerKey *windows.GUID
	providerData fwpByteBlob
	weight uint16 // sublayer weight
	_ [6]byte // padding to 72 bytes
}
// fwpmSession0. 72 bytes on 64-bit.
// Mirrors FWPM_SESSION0 as passed to FwpmEngineOpen0. openDynamicEngine sets only flags and leaves every other
// member zeroed.
type fwpmSession0 struct {
	sessionKey windows.GUID
	displayData fwpmDisplayData0
	flags uint32 // FWPM_SESSION flags, e.g. fwpmSessionFlagDynamic
	txnWaitTimeoutInMSec uint32
	processId uint32
	_ uint32 // padding before *SID
	sid *windows.SID
	username *uint16
	kernelMode uint8
	_ [7]byte // tail padding
}
// fwpuclnt.dll bindings. Only the calls we use.
// The DLL and its procedures are declared through the lazy x/sys/windows wrappers, so nothing is loaded or
// resolved until a procedure is first called.
var (
	modFwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
	procFwpmEngineOpen0 = modFwpuclnt.NewProc("FwpmEngineOpen0") // opens an engine session (openDynamicEngine)
	procFwpmEngineClose0 = modFwpuclnt.NewProc("FwpmEngineClose0") // closes an engine handle (Session.Close, newSession)
	procFwpmSubLayerAdd0 = modFwpuclnt.NewProc("FwpmSubLayerAdd0") // adds a sublayer (registerSublayer)
	procFwpmFilterAdd0 = modFwpuclnt.NewProc("FwpmFilterAdd0") // adds a filter (addInterfaceFilter, addUDPPortFilter)
)
// Session holds the WFP engine handle for a single bypass operation. The handle owns a dynamic session:
// when it is closed, every WFP object added during the session (sublayer + filters) is automatically deleted by
// Windows. That gives us correct cleanup even if the host process is killed hard between Permit* and Close.
//
// Sessions are created by PermitInterface and PermitUDPPort and released with Close.
type Session struct {
	engine uintptr // handle produced by FwpmEngineOpen0; reset to 0 once Close has run
}
// Close shuts the WFP engine handle owned by the session, at which point Windows deletes every dynamic object
// (sublayer + filters) that was installed through it.
//
// It is a no-op on a nil receiver and on a session whose handle has already been released, so it may be called
// more than once.
func (s *Session) Close() {
	if s != nil && s.engine != 0 {
		// Best effort: the FwpmEngineClose0 status is deliberately not inspected.
		procFwpmEngineClose0.Call(s.engine)
		s.engine = 0
	}
}
// PermitInterface opens a new dynamic WFP session and adds one PERMIT filter per address family
// (FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6), each scoped to the network interface identified by luid.
// Inbound traffic on that interface then bypasses Windows Defender Firewall.
//
// On success the caller owns the returned Session and must Close it to remove the filters. If either filter cannot
// be added the session is closed again and a nil Session is returned together with the error.
func PermitInterface(luid uint64) (*Session, error) {
	sess, key, err := newSession()
	if err != nil {
		return nil, err
	}

	// IPv4 first, then IPv6; stop at the first failure.
	var failure error
	if v4Err := addInterfaceFilter(sess.engine, key, fwpmLayerAleAuthRecvAcceptV4, luid); v4Err != nil {
		failure = fmt.Errorf("add v4 filter: %w", v4Err)
	} else if v6Err := addInterfaceFilter(sess.engine, key, fwpmLayerAleAuthRecvAcceptV6, luid); v6Err != nil {
		failure = fmt.Errorf("add v6 filter: %w", v6Err)
	}

	if failure != nil {
		// Closing the handle also discards anything that was already installed.
		sess.Close()
		return nil, failure
	}
	return sess, nil
}
// PermitUDPPort opens a new dynamic WFP session and adds one PERMIT filter per address family
// (FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6), each matching UDP traffic whose local port is port.
// Inbound UDP to that port on any interface then bypasses Windows Defender Firewall.
//
// On success the caller owns the returned Session and must Close it to remove the filters. If either filter cannot
// be added the session is closed again and a nil Session is returned together with the error.
func PermitUDPPort(port uint16) (*Session, error) {
	sess, key, err := newSession()
	if err != nil {
		return nil, err
	}

	// IPv4 first, then IPv6; stop at the first failure.
	var failure error
	if v4Err := addUDPPortFilter(sess.engine, key, fwpmLayerAleAuthRecvAcceptV4, port); v4Err != nil {
		failure = fmt.Errorf("add v4 filter: %w", v4Err)
	} else if v6Err := addUDPPortFilter(sess.engine, key, fwpmLayerAleAuthRecvAcceptV6, port); v6Err != nil {
		failure = fmt.Errorf("add v6 filter: %w", v6Err)
	}

	if failure != nil {
		// Closing the handle also discards anything that was already installed.
		sess.Close()
		return nil, failure
	}
	return sess, nil
}
// newSession opens a dynamic WFP engine session and registers the sublayer that this package's filters are placed
// in. It returns the Session wrapping the engine handle together with the key of the new sublayer.
//
// If the sublayer cannot be registered the freshly opened engine handle is closed before the error is returned.
func newSession() (*Session, windows.GUID, error) {
	var noKey windows.GUID

	handle, err := openDynamicEngine()
	if err != nil {
		return nil, noKey, err
	}

	key, err := registerSublayer(handle)
	if err != nil {
		// No Session exists yet, so release the raw handle directly.
		procFwpmEngineClose0.Call(handle)
		return nil, noKey, err
	}

	return &Session{engine: handle}, key, nil
}
// openDynamicEngine calls FwpmEngineOpen0 against the local machine with the dynamic session flag set and WinNT
// RPC authentication, and returns the resulting engine handle. A non-zero status is reported as an error carrying
// the code in hex.
func openDynamicEngine() (uintptr, error) {
	var (
		handle uintptr
		params = fwpmSession0{flags: fwpmSessionFlagDynamic}
	)

	status, _, _ := procFwpmEngineOpen0.Call(
		0, // serverName == NULL (local)
		uintptr(rpcCAuthnWinNT),
		0, // authIdentity == NULL
		uintptr(unsafe.Pointer(&params)),
		uintptr(unsafe.Pointer(&handle)),
	)
	if status == 0 {
		return handle, nil
	}
	return 0, fmt.Errorf("FwpmEngineOpen0: 0x%x", status)
}
// registerSublayer creates a sublayer under a freshly generated GUID and adds it to the given engine session.
// The sublayer is given weight 0xFFFF so that its filters arbitrate above WDF's default sublayer. It carries no
// PERSISTENT flag, so it is dynamic and disappears when the engine handle closes.
//
// It returns the sublayer key on success; GUID generation failures and non-zero FwpmSubLayerAdd0 statuses are
// returned as errors.
func registerSublayer(engine uintptr) (windows.GUID, error) {
	var none windows.GUID

	key, err := windows.GenerateGUID()
	if err != nil {
		return none, fmt.Errorf("GenerateGUID for sublayer: %w", err)
	}

	// The conversion only fails for strings containing a NUL byte, which these literals do not.
	label, _ := windows.UTF16PtrFromString("Nebula WDF bypass sublayer")
	detail, _ := windows.UTF16PtrFromString("Permit filters bypassing Windows Defender Firewall")

	var sublayer fwpmSublayer0
	sublayer.subLayerKey = key
	sublayer.displayData.name = label
	sublayer.displayData.description = detail
	sublayer.weight = 0xFFFF

	status, _, _ := procFwpmSubLayerAdd0.Call(
		engine,
		uintptr(unsafe.Pointer(&sublayer)),
		0, // sd == NULL
	)
	if status == 0 {
		return key, nil
	}
	return none, fmt.Errorf("FwpmSubLayerAdd0: 0x%x", status)
}
// addInterfaceFilter installs a single PERMIT filter, in the given layer and sublayer, that matches traffic whose
// local interface LUID equals luid.
//
// engine is an open WFP engine handle, sublayerKey identifies the sublayer the filter is placed in, and layer is the
// FWPM layer GUID to filter at. A non-zero FwpmFilterAdd0 status is returned as an error carrying the code in hex.
func addInterfaceFilter(engine uintptr, sublayerKey, layer windows.GUID, luid uint64) error {
	// The conversion only fails for strings containing a NUL byte, which these literals do not.
	name, _ := windows.UTF16PtrFromString("Nebula allow interface inbound")
	desc, _ := windows.UTF16PtrFromString("Permits inbound traffic on a nebula interface")

	// FWP_UINT64 is stored by pointer in the FWP_VALUE0 union, and fwpValue0.value is a plain uintptr that the Go
	// runtime does not trace. Taking the address of the luid parameter would leave that uintptr pointing into this
	// goroutine's stack, which the runtime may relocate before FwpmFilterAdd0 dereferences it. Copy the value into
	// its own allocation and pin it until the call has returned.
	pinnedLUID := new(uint64)
	*pinnedLUID = luid
	var pinner runtime.Pinner
	pinner.Pin(pinnedLUID)
	defer pinner.Unpin()

	cond := fwpmFilterCondition0{
		fieldKey:  fwpmConditionIPLocalInterface,
		matchType: fwpMatchEqual,
		conditionValue: fwpValue0{
			type_: fwpUint64,
			value: uintptr(unsafe.Pointer(pinnedLUID)),
		},
	}
	filter := fwpmFilter0{
		// filterKey left zero: WFP assigns one when the filter is added.
		displayData: fwpmDisplayData0{name: name, description: desc},
		flags:       fwpmFilterFlagClearActionRight,
		layerKey:    layer,
		subLayerKey: sublayerKey,
		// NOTE(review): an FWP_UINT8 weight of 15 is presumably the highest weight range -- confirm against the
		// FWPM_FILTER0 documentation.
		weight:              fwpValue0{type_: fwpUint8, value: uintptr(15)},
		numFilterConditions: 1,
		filterCondition:     &cond,
		action:              fwpmAction0{actionType: fwpActionPermit},
	}

	// filter (and cond, which it references through a real Go pointer) is converted to uintptr inside the call
	// expression itself.
	r1, _, _ := procFwpmFilterAdd0.Call(
		engine,
		uintptr(unsafe.Pointer(&filter)),
		0, // sd == NULL
		0, // id == NULL
	)
	if r1 != 0 {
		return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1)
	}
	return nil
}
// addUDPPortFilter installs a single PERMIT filter, in the given layer and sublayer, whose two conditions are
// (IP_PROTOCOL == UDP) AND (IP_LOCAL_PORT == port).
//
// FWP_UINT8 and FWP_UINT16 are <= 32 bits, so both condition values live inline in the FWP_VALUE0 union and no
// out-of-line storage has to be kept alive. A non-zero FwpmFilterAdd0 status is returned as an error carrying the
// code in hex.
func addUDPPortFilter(engine uintptr, sublayerKey, layer windows.GUID, port uint16) error {
	// The conversion only fails for strings containing a NUL byte, which these literals do not.
	label, _ := windows.UTF16PtrFromString("Nebula allow UDP port inbound")
	detail, _ := windows.UTF16PtrFromString("Permits inbound UDP to a nebula listener port")

	// The conditions must be contiguous: WFP receives a pointer to the first plus a count.
	var conds [2]fwpmFilterCondition0

	// Condition 0: the IP protocol is UDP.
	conds[0].fieldKey = fwpmConditionIPProtocol
	conds[0].matchType = fwpMatchEqual
	conds[0].conditionValue = fwpValue0{type_: fwpUint8, value: uintptr(ipprotoUDP)}

	// Condition 1: the local port is the requested one.
	conds[1].fieldKey = fwpmConditionIPLocalPort
	conds[1].matchType = fwpMatchEqual
	conds[1].conditionValue = fwpValue0{type_: fwpUint16, value: uintptr(port)}

	var filter fwpmFilter0
	filter.displayData = fwpmDisplayData0{name: label, description: detail}
	filter.flags = fwpmFilterFlagClearActionRight
	filter.layerKey = layer
	filter.subLayerKey = sublayerKey
	filter.weight = fwpValue0{type_: fwpUint8, value: uintptr(15)}
	filter.numFilterConditions = 2
	filter.filterCondition = &conds[0]
	filter.action = fwpmAction0{actionType: fwpActionPermit}

	status, _, _ := procFwpmFilterAdd0.Call(
		engine,
		uintptr(unsafe.Pointer(&filter)),
		0, // sd == NULL
		0, // id == NULL
	)
	if status == 0 {
		return nil
	}
	return fmt.Errorf("FwpmFilterAdd0: 0x%x", status)
}