Merge remote-tracking branch 'origin/master' into multiport

This commit is contained in:
Wade Simmons
2026-05-27 22:24:53 -04:00
83 changed files with 4388 additions and 651 deletions
+113
View File
@@ -0,0 +1,113 @@
name: Code-sign Windows binaries
description: >
Sign every .exe under a given path in place via the DefinedNet code-signer
Lambda. If `role` or `bucket` is empty, logs a notice and skips signing so
forks and dev branches without AWS access still produce usable builds.
inputs:
path:
description: "Directory whose .exe files should be signed in place"
required: true
role:
description: "IAM role ARN to assume via OIDC; empty disables signing"
required: false
default: ""
bucket:
description: "S3 staging bucket the code-signer Lambda reads from; empty disables signing"
required: false
default: ""
region:
description: "AWS region for the role and Lambda"
required: false
default: "us-east-2"
function-name:
description: "Code-signer Lambda function name"
required: false
default: "code-signer"
key-prefix:
description: "S3 key prefix the caller is authorized to write under"
required: false
default: "code-signing/slackhq/nebula"
runs:
using: composite
steps:
- name: Skip notice
if: inputs.role == '' || inputs.bucket == ''
shell: sh
run: echo "::notice::code-signer role or bucket not set; skipping code signing."
- name: Configure AWS credentials
if: inputs.role != '' && inputs.bucket != ''
uses: aws-actions/configure-aws-credentials@v6
with:
role-to-assume: ${{ inputs.role }}
aws-region: ${{ inputs.region }}
# Default is 12 retries to ride out IAM trust-policy propagation; once
# the role is stable we want a real misconfiguration to fail fast.
retry-max-attempts: 5
- name: Sign .exe files
if: inputs.role != '' && inputs.bucket != ''
shell: sh
env:
SIGN_PATH: ${{ inputs.path }}
BUCKET: ${{ inputs.bucket }}
FUNCTION_NAME: ${{ inputs.function-name }}
KEY_PREFIX: ${{ inputs.key-prefix }}
run: |
set -eu
RUN="${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
find "$SIGN_PATH" -name '*.exe' -print | while read -r path
do
rel=${path#"$SIGN_PATH"/}
file=$(basename "$path")
name=${file%.exe}
prefix="${KEY_PREFIX}/${RUN}"
src="${prefix}/unsigned/${rel}"
dst="${prefix}/signed/${rel}"
echo "::group::Sign ${rel}"
echo "Uploading unsigned to s3://${BUCKET}/${src}"
aws s3 cp --no-progress "$path" "s3://${BUCKET}/${src}" >/dev/null
echo "Invoking ${FUNCTION_NAME} Lambda"
payload=$(jq -nc \
--arg s "$src" \
--arg d "$dst" \
--arg p "$name" \
'{source_key: $s, dest_key: $d, program_name: $p}')
meta=$(aws lambda invoke \
--function-name "$FUNCTION_NAME" \
--cli-binary-format raw-in-base64-out \
--payload "$payload" \
--output json \
/tmp/sign-resp.json)
if echo "$meta" | jq -e '.FunctionError != null' >/dev/null
then
echo "::endgroup::"
echo "::error::code-signer Lambda failed for ${rel}"
cat /tmp/sign-resp.json >&2
exit 1
fi
echo "Downloading signed back to ${path}"
aws s3 cp --no-progress "s3://${BUCKET}/${dst}" "$path" >/dev/null
aws s3 rm "s3://${BUCKET}/${src}" >/dev/null 2>&1 || true
aws s3 rm "s3://${BUCKET}/${dst}" >/dev/null 2>&1 || true
# Sanity-check the bytes we got back actually carry an Authenticode
# signature that this machine can validate end to end.
status=$(powershell -NoProfile -Command "(Get-AuthenticodeSignature -FilePath '$path').Status" | tr -d '\r')
if [ "$status" != "Valid" ]
then
echo "::endgroup::"
echo "::error::${rel} signature status: ${status} (expected Valid)"
exit 1
fi
echo "Signed ${rel} (sha256=$(jq -r '.sha256' /tmp/sign-resp.json), status=${status})"
echo "::endgroup::"
done
-34
View File
@@ -1,34 +0,0 @@
name: gofmt
on:
push:
branches:
- master
pull_request:
paths:
- '.github/workflows/gofmt.yml'
- '**.go'
jobs:
gofmt:
name: Run gofmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
- name: gofmt
run: |
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
then
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
exit 1
fi
+18 -8
View File
@@ -24,7 +24,7 @@ jobs:
mv build/*.tar.gz release
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: linux-latest
path: release
@@ -32,6 +32,9 @@ jobs:
build-windows:
name: Build Windows
runs-on: windows-latest
permissions:
id-token: write
contents: read
steps:
- uses: actions/checkout@v6
@@ -54,8 +57,15 @@ jobs:
mkdir build\dist\windows
mv dist\windows\wintun build\dist\windows\
- name: Code-sign
uses: ./.github/actions/code-sign
with:
path: build
role: ${{ secrets.DEFINED_CODE_SIGNER_ROLE }}
bucket: ${{ secrets.DEFINED_CODE_SIGNER_BUCKET }}
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: windows-latest
path: build
@@ -75,7 +85,7 @@ jobs:
- name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v6
uses: Apple-Actions/import-codesign-certs@v7
with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
@@ -104,7 +114,7 @@ jobs:
fi
- name: Upload artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: darwin-latest
path: ./release/*
@@ -128,21 +138,21 @@ jobs:
- name: Download artifacts
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
name: linux-latest
path: artifacts
- name: Login to Docker Hub
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
username: ${{ vars.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v4
- name: Build and push images
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
@@ -163,7 +173,7 @@ jobs:
- uses: actions/checkout@v6
- name: Download artifacts
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
path: artifacts
+81 -16
View File
@@ -14,10 +14,18 @@ on:
- 'go.sum'
jobs:
smoke-extra:
smoke-extra-libvirt:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run extra smoke tests
name: ${{ matrix.target }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
target:
- freebsd-amd64
- openbsd-amd64
- netbsd-amd64
- linux-amd64-ipv6disable
env:
VAGRANT_DEFAULT_PROVIDER: libvirt
steps:
@@ -40,28 +48,85 @@ jobs:
sudo chmod 666 /var/run/libvirt/libvirt-sock
vagrant plugin install vagrant-libvirt
- name: freebsd-amd64
run: make smoke-vagrant/freebsd-amd64
- name: ${{ matrix.target }}
run: make smoke-vagrant/${{ matrix.target }}
- name: openbsd-amd64
run: make smoke-vagrant/openbsd-amd64
timeout-minutes: 30
- name: netbsd-amd64
run: make smoke-vagrant/netbsd-amd64
# linux-386 needs VirtualBox, which conflicts with KVM/libvirt -- isolated job.
smoke-extra-virtualbox:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: linux-386
runs-on: ubuntu-latest
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
steps:
- name: linux-amd64-ipv6disable
run: make smoke-vagrant/linux-amd64-ipv6disable
- uses: actions/checkout@v6
# linux-386 runs last because it requires disabling KVM to use VirtualBox,
# which prevents libvirt (used by the other tests) from working after this point.
- name: install virtualbox for i386 test
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: add hashicorp source
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
- name: install vagrant and virtualbox
run: |
sudo apt-get install -y virtualbox
sudo apt-get update && sudo apt-get install -y vagrant virtualbox
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
- name: linux-386
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
run: make smoke-vagrant/linux-386
timeout-minutes: 30
smoke-windows:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run windows smoke test
runs-on: windows-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
# WSL2 + Ubuntu so the smoke can run a real linux peer with its own
# netns. iputils-ping is needed for the in-WSL ping check. WSL1 has no
# real kernel and would lack /dev/net/tun, so we have to force WSL2.
- uses: Vampire/setup-wsl@v3
with:
distribution: Ubuntu-24.04
additional-packages: iputils-ping iproute2
# Vampire/setup-wsl provisions WSL1 even when the WSL2 platform is present.
# Convert the distro to WSL2 explicitly before we try to use /dev/net/tun.
- name: convert distro to WSL2
shell: pwsh
run: |
wsl --set-version Ubuntu-24.04 2
wsl --shutdown
wsl --list --verbose
- name: build windows nebula
run: make bin-windows
- name: build linux nebula for WSL
shell: bash
env:
GOOS: linux
GOARCH: amd64
run: |
mkdir -p build/linux-amd64
go build -o build/linux-amd64/nebula ./cmd/nebula
- name: run smoke-windows
shell: pwsh
working-directory: ./.github/workflows/smoke
run: ./smoke-windows.ps1
timeout-minutes: 15
+272
View File
@@ -0,0 +1,272 @@
#!/usr/bin/env pwsh
# Windows smoke test for the nebula tun + UDP + NLM code paths.
#
# Topology:
# - lighthouse runs natively on the Windows host (wintun + windows UDP)
# - peer runs inside WSL2 (Linux build of nebula, /dev/net/tun)
#
# WSL2 gives us a real netns boundary so the loopback fast-path on Windows
# does not short-circuit the overlay -- when WSL pings the lighthouse VPN IP,
# Linux has no idea that IP is local to the Windows host, so the packet is
# forced through nebula. Same in reverse.
$ErrorActionPreference = 'Stop'
# wsl.exe emits UTF-16 LE by default which PowerShell reads as bytes, mangling
# every captured string. WSL_UTF8 makes wsl.exe emit UTF-8 instead.
$env:WSL_UTF8 = '1'
$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.."
$Nebula = Join-Path $RepoRoot 'nebula.exe'
$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe'
$NebulaLinux = Join-Path $RepoRoot 'build\linux-amd64\nebula'
if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" }
if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" }
if (-not (Test-Path $NebulaLinux)) { throw "missing $NebulaLinux; build the linux nebula first" }
# Matches the distro installed by Vampire/setup-wsl in smoke-extra.yml.
$Distro = 'Ubuntu-24.04'
$listed = (wsl --list --quiet 2>$null) -join "`n"
if ($listed -notmatch [regex]::Escape($Distro)) {
throw "WSL distro $Distro not registered. Got: $listed"
}
Write-Host "Using WSL distro: $Distro"
# Windows host as seen from inside WSL: WSL's default-route gateway. We extract
# it with a regex rather than awk fields so PowerShell does not eat any '$N'
# tokens, and tabs/double-spaces in `ip route` output do not confuse a cut.
$ipCmd = 'ip route show default | grep -oE "([0-9]+\.){3}[0-9]+" | head -1'
$WindowsIp = (wsl -d $Distro -- bash -c $ipCmd).Trim()
if (-not $WindowsIp) { throw "could not determine Windows host IP from WSL" }
Write-Host "Windows host IP from WSL: $WindowsIp"
$WorkDir = Join-Path $env:TEMP 'nebula-smoke-windows'
if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir }
New-Item -ItemType Directory -Path $WorkDir | Out-Null
$WslDir = '/tmp/nebula-smoke'
wsl -d $Distro -- bash -c "rm -rf $WslDir && mkdir -p $WslDir" | Out-Null
$DevName = 'nebula-smoke'
$Ip1 = '192.168.241.1'
$Ip2 = '192.168.241.2'
$Port = 4242
& $NebulaCert ca -name 'smoke-ca' -out-crt "$WorkDir\ca.crt" -out-key "$WorkDir\ca.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" }
& $NebulaCert sign -name 'lighthouse' -networks "$Ip1/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\lighthouse.crt" -out-key "$WorkDir\lighthouse.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign lighthouse failed (exit $LASTEXITCODE)" }
& $NebulaCert sign -name 'peer' -networks "$Ip2/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\peer.crt" -out-key "$WorkDir\peer.key"
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign peer failed (exit $LASTEXITCODE)" }
# Windows lighthouse config.
@"
pki:
ca: $WorkDir\ca.crt
cert: $WorkDir\lighthouse.crt
key: $WorkDir\lighthouse.key
static_host_map: {}
lighthouse:
am_lighthouse: true
interval: 60
hosts: []
listen:
host: 0.0.0.0
port: $Port
tun:
disabled: false
dev: $DevName
drop_local_broadcast: false
drop_multicast: false
tx_queue: 500
mtu: 1300
network_category: private
logging:
level: info
format: text
firewall:
outbound_action: drop
inbound_action: drop
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
"@ | Out-File -FilePath "$WorkDir\lighthouse.yml" -Encoding utf8
# WSL peer config (paths are POSIX, deliberately).
@"
pki:
ca: $WslDir/ca.crt
cert: $WslDir/peer.crt
key: $WslDir/peer.key
static_host_map:
"${Ip1}": ["${WindowsIp}:$Port"]
lighthouse:
am_lighthouse: false
interval: 60
hosts:
- "${Ip1}"
listen:
host: 0.0.0.0
port: 0
tun:
disabled: false
dev: nebula1
drop_local_broadcast: false
drop_multicast: false
tx_queue: 500
mtu: 1300
logging:
level: info
format: text
firewall:
outbound_action: drop
inbound_action: drop
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
"@ | Out-File -FilePath "$WorkDir\peer.yml" -Encoding utf8
# Stage WSL artifacts. Convert Windows paths to WSL paths ourselves rather than
# calling `wslpath`, because PowerShell's argument-passing to external EXEs
# strips backslashes from path arguments in ways that are hard to escape around.
function ConvertTo-WslPath {
param([string]$WindowsPath)
if ($WindowsPath -notmatch '^([A-Za-z]):\\(.*)$') {
throw "cannot convert path to WSL: $WindowsPath"
}
return "/mnt/$($matches[1].ToLower())/$($matches[2].Replace('\','/'))"
}
$WslWorkDir = ConvertTo-WslPath $WorkDir
$WslNebulaPath = ConvertTo-WslPath $NebulaLinux
wsl -d $Distro -- bash -c "cp '$WslWorkDir/ca.crt' '$WslWorkDir/peer.crt' '$WslWorkDir/peer.key' '$WslWorkDir/peer.yml' $WslDir/ && cp '$WslNebulaPath' $WslDir/nebula && chmod +x $WslDir/nebula"
# Make sure WSL has tun support and /dev/net/tun is usable before starting
# nebula. Diagnostics first so a fail here points at the real problem (e.g.
# WSL1 distros do not have a real kernel and will not have tun).
Write-Host '=== WSL diagnostic ==='
wsl --version 2>&1 | Out-Host
wsl --list --verbose 2>&1 | Out-Host
wsl -d $Distro -u root -- uname -a | Out-Host
wsl -d $Distro -u root -- bash -c "modprobe tun 2>&1 || true; mkdir -p /dev/net; [ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; chmod 600 /dev/net/tun; ls -l /dev/net/tun"
if ($LASTEXITCODE -ne 0) { throw "failed to prepare /dev/net/tun in WSL (TUN support missing?)" }
# Deliberately no New-NetFirewallRule calls here -- nebula's windows_bypass_wdf
# feature is supposed to install WFP permit filters that let inbound traffic
# through Windows Defender Firewall on its own. If this smoke regresses, that
# feature regressed.
$lhOut = Join-Path $WorkDir 'lighthouse.out.log'
$lhErr = Join-Path $WorkDir 'lighthouse.err.log'
$lhProc = Start-Process -FilePath $Nebula -ArgumentList @('-config', "$WorkDir\lighthouse.yml") `
-PassThru -NoNewWindow `
-RedirectStandardOutput $lhOut `
-RedirectStandardError $lhErr
# Run nebula in WSL as root with no sudo + no shell wrapper. PowerShell's
# Start-Process arg quoting mangles `bash -c "..."` strings that contain
# spaces/redirections, so we skip bash entirely and let Start-Process do the
# stdout/stderr capture itself.
$peerOut = Join-Path $WorkDir 'peer.out.log'
$peerErr = Join-Path $WorkDir 'peer.err.log'
$peerProc = Start-Process -FilePath 'wsl' `
-ArgumentList @('-d', $Distro, '-u', 'root', '--', "$WslDir/nebula", '-config', "$WslDir/peer.yml") `
-PassThru -NoNewWindow `
-RedirectStandardOutput $peerOut `
-RedirectStandardError $peerErr
function Wait-Until {
param([scriptblock]$Predicate, [int]$TimeoutSec, [string]$What)
$deadline = (Get-Date).AddSeconds($TimeoutSec)
while ((Get-Date) -lt $deadline) {
if (& $Predicate) { return }
Start-Sleep -Milliseconds 500
}
throw "timed out waiting for: $What"
}
try {
Wait-Until -TimeoutSec 30 -What "windows wintun adapter $DevName with NetworkCategory=Private" -Predicate {
if ($lhProc.HasExited) { throw "lighthouse exited (code $($lhProc.ExitCode)) before tun was ready" }
$p = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue
$p -and ("$($p.NetworkCategory)" -ieq 'Private')
}
Write-Host "OK: $DevName NetworkCategory=Private"
Wait-Until -TimeoutSec 30 -What "WSL nebula1 with $Ip2" -Predicate {
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before tun was ready" }
$r = wsl -d $Distro -u root -- bash -c "ip -o addr show nebula1 2>/dev/null | grep -q 'inet $Ip2' && echo yes"
("$r").Trim() -eq 'yes'
}
Write-Host "OK: WSL nebula1 has $Ip2"
Wait-Until -TimeoutSec 30 -What "ping from WSL peer to windows lighthouse ($Ip1)" -Predicate {
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before ping succeeded" }
$r = wsl -d $Distro -u root -- bash -c "ping -c1 -W1 $Ip1 >/dev/null 2>&1 && echo OK"
("$r").Trim() -eq 'OK'
}
Write-Host "OK: WSL peer -> windows lighthouse"
Wait-Until -TimeoutSec 30 -What "ping from windows lighthouse to WSL peer ($Ip2)" -Predicate {
$null = & ping.exe -n 1 -w 1000 $Ip2
$LASTEXITCODE -eq 0
}
Write-Host "OK: windows lighthouse -> WSL peer"
Write-Host ''
Write-Host 'All smoke checks passed.'
}
catch {
Write-Host ''
Write-Host '=== lighthouse stdout ==='
Get-Content $lhOut -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== lighthouse stderr ==='
Get-Content $lhErr -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== peer stdout ==='
Get-Content $peerOut -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== peer stderr ==='
Get-Content $peerErr -ErrorAction SilentlyContinue | Out-Host
Write-Host '=== nebula WFP filters ==='
# Dump nebula-installed filters so we can verify they got registered with
# the conditions we expect.
$wfpDump = Join-Path $WorkDir 'wfp.xml'
netsh wfp show filters file=$wfpDump 2>&1 | Out-Null
if (Test-Path $wfpDump) {
Select-String -Path $wfpDump -Pattern 'Nebula' -Context 0,80 -ErrorAction SilentlyContinue | Out-Host
}
throw
}
finally {
if (-not $lhProc.HasExited) {
Stop-Process -Id $lhProc.Id -Force -ErrorAction SilentlyContinue
$lhProc.WaitForExit(5000) | Out-Null
}
wsl -d $Distro -u root -- bash -c "pkill -f $WslDir/nebula 2>/dev/null; true" | Out-Null
# pkill returns 1 when no match and wsl propagates that; the smoke is done
# so we don't want it to leak into the script's exit code.
$global:LASTEXITCODE = 0
if ($peerProc -and -not $peerProc.HasExited) {
Stop-Process -Id $peerProc.Id -Force -ErrorAction SilentlyContinue
}
}
@@ -1,7 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure("2") do |config|
config.vm.box = "generic/netbsd9"
config.vm.box = "DefinedNet/netbsd10"
config.vm.synced_folder "../build", "/nebula", type: "rsync"
end
+95 -77
View File
@@ -13,8 +13,8 @@ on:
- 'go.sum'
jobs:
test-linux:
name: Build all and test on ubuntu-linux
static:
name: Static checks
runs-on: ubuntu-latest
steps:
@@ -25,8 +25,16 @@ jobs:
go-version: '1.25'
check-latest: true
- name: Build
run: make all
- name: Install goimports
run: go install golang.org/x/tools/cmd/goimports@latest
- name: gofmt
run: |
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
then
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
exit 1
fi
- name: Vet
run: make vet
@@ -36,66 +44,38 @@ jobs:
with:
version: v2.5
- name: Test
run: make test
- name: End 2 end
run: make e2evv
- name: Build test mobile
run: make build-test-mobile
- uses: actions/upload-artifact@v6
with:
name: e2e packet flow linux-latest
path: e2e/mermaid/linux-latest
if-no-files-found: warn
test-linux-boringcrypto:
name: Build and test on linux with boringcrypto
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make bin-boringcrypto
- name: Test
run: make test-boringcrypto
- name: End 2 end
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
test-linux-pkcs11:
name: Build and test on linux with pkcs11
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build
run: make bin-pkcs11
- name: Test
run: make test-pkcs11
test:
name: Build and test on ${{ matrix.os }}
name: Test ${{ matrix.name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [windows-latest, macos-latest]
include:
- name: linux
os: ubuntu-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
- name: linux-boringcrypto
os: ubuntu-latest
build-cmd: make bin-boringcrypto
test-cmd: make test-boringcrypto
e2e-cmd: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
- name: linux-pkcs11
os: ubuntu-latest
build-cmd: make bin-pkcs11
test-cmd: make test-pkcs11
e2e-cmd: ''
- name: macos
os: macos-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
- name: windows
os: windows-latest
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
test-cmd: make test
e2e-cmd: make e2evv
steps:
- uses: actions/checkout@v6
@@ -105,28 +85,66 @@ jobs:
go-version: '1.25'
check-latest: true
- name: Build nebula
run: go build ./cmd/nebula
- name: Build
run: ${{ matrix.build-cmd }}
- name: Build nebula-cert
run: go build ./cmd/nebula-cert
- name: Vet
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.5
- name: Cross-build darwin-amd64
if: matrix.name == 'macos'
run: GOARCH=amd64 go build -o /tmp/nebula-amd64 ./cmd/nebula && GOARCH=amd64 go build -o /tmp/nebula-cert-amd64 ./cmd/nebula-cert
- name: Test
run: make test
run: ${{ matrix.test-cmd }}
- name: End 2 end
run: make e2evv
if: matrix.e2e-cmd != ''
run: ${{ matrix.e2e-cmd }}
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
if: matrix.e2e-cmd != '' && always()
with:
name: e2e packet flow ${{ matrix.os }}
path: e2e/mermaid/${{ matrix.os }}
name: e2e packet flow ${{ matrix.name }}
path: e2e/mermaid/
if-no-files-found: warn
cross-build:
name: Cross-build ${{ matrix.name }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- {name: linux-arm, make-target: all-cross-linux-arm}
- {name: linux-mips, make-target: all-cross-linux-mips}
- {name: linux-other, make-target: all-cross-linux-other}
- {name: freebsd, make-target: all-freebsd}
- {name: openbsd, make-target: all-openbsd}
- {name: netbsd, make-target: all-netbsd}
- {name: windows, make-target: all-cross-windows}
- {name: mobile, make-target: build-test-mobile}
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true
- name: Build ${{ matrix.name }}
run: make -j"$(nproc)" ${{ matrix.make-target }}
finish:
name: CI status
if: always()
needs: [static, test, cross-build]
runs-on: ubuntu-latest
steps:
- name: Fail if any upstream job failed
if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled')
run: |
echo "upstream results: ${{ toJSON(needs) }}"
exit 1
- name: All upstream jobs passed
run: echo "ok"
+42 -1
View File
@@ -60,6 +60,18 @@ ALL = $(ALL_LINUX) \
windows-amd64 \
windows-arm64
# Cross-build shards used by .github/workflows/test.yml — same as ALL_*
# but with the arch that has a native CI runner removed, so the cross-build
# job is not duplicating coverage the native test jobs already give.
ALL_CROSS_LINUX = $(filter-out linux-amd64,$(ALL_LINUX))
# ALL_CROSS_LINUX further split into family sub-shards so each can run on
# its own CI runner in parallel. Union of the three must equal
# ALL_CROSS_LINUX; adding a new linux arch goes into the matching family.
ALL_CROSS_LINUX_ARM = linux-arm-5 linux-arm-6 linux-arm-7 linux-arm64
ALL_CROSS_LINUX_MIPS = linux-mips linux-mipsle linux-mips64 linux-mips64le linux-mips-softfloat
ALL_CROSS_LINUX_OTHER = linux-386 linux-ppc64le linux-riscv64 linux-loong64
e2e:
$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
@@ -82,6 +94,35 @@ DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
all-linux: $(ALL_LINUX:%=build/%/nebula) $(ALL_LINUX:%=build/%/nebula-cert)
all-freebsd: $(ALL_FREEBSD:%=build/%/nebula) $(ALL_FREEBSD:%=build/%/nebula-cert)
all-openbsd: $(ALL_OPENBSD:%=build/%/nebula) $(ALL_OPENBSD:%=build/%/nebula-cert)
all-netbsd: $(ALL_NETBSD:%=build/%/nebula) $(ALL_NETBSD:%=build/%/nebula-cert)
all-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert build/darwin-arm64/nebula build/darwin-arm64/nebula-cert
all-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
# CI cross-build shards. darwin-arm64 is covered by the native macos-latest
# job; windows-amd64 is covered by the native windows-latest job; both are
# omitted here to avoid building them a second time. darwin-amd64 stays in
# all-cross-darwin because intel mac is only a labeled/master-time native
# job, so PRs still need cross-build coverage for it.
all-cross-linux: $(ALL_CROSS_LINUX:%=build/%/nebula) $(ALL_CROSS_LINUX:%=build/%/nebula-cert)
all-cross-linux-arm: $(ALL_CROSS_LINUX_ARM:%=build/%/nebula) $(ALL_CROSS_LINUX_ARM:%=build/%/nebula-cert)
all-cross-linux-mips: $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula) $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula-cert)
all-cross-linux-other: $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula) $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula-cert)
all-cross-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
all-cross-windows: build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
docker: docker/linux-$(shell go env GOARCH)
release: $(ALL:%=build/nebula-%.tar.gz)
@@ -240,5 +281,5 @@ smoke-vagrant/%: bin-docker build/%/nebula
cd .github/workflows/smoke/ && ./smoke-vagrant.sh $*
.FORCE:
.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
.PHONY: all all-linux all-freebsd all-openbsd all-netbsd all-darwin all-windows all-cross-linux all-cross-linux-arm all-cross-linux-mips all-cross-linux-other all-cross-darwin all-cross-windows bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
.DEFAULT_GOAL := bin
+4
View File
@@ -217,6 +217,10 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp
return nil, err
}
if signer.Certificate.Curve() != c.Curve() {
return nil, ErrCurveMismatch
}
if signer.Certificate.Expired(now) {
return nil, ErrRootExpired
}
+28
View File
@@ -654,3 +654,31 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV2_CurveMismatch(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.0.0.1/24")
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1}, nil, []string{"test"})
fp, _ := c.Fingerprint()
_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
require.NoError(t, err)
//
c2 := c.(*certificateV2)
c2.curve = Curve_CURVE25519
fp, _ = c.Fingerprint()
_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
require.Error(t, err)
}
+3
View File
@@ -112,6 +112,9 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
}
switch c.details.curve {
case Curve_CURVE25519:
if len(key) != ed25519.PublicKeySize {
return false //avoids a panic internal to ed25519
}
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+3
View File
@@ -151,6 +151,9 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
switch c.curve {
case Curve_CURVE25519:
if len(key) != ed25519.PublicKeySize {
return false //avoids a panic internal to ed25519
}
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+1
View File
@@ -22,6 +22,7 @@ var (
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present")
ErrCurveMismatch = errors.New("certificate curve does not match CA")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
+10 -4
View File
@@ -13,6 +13,12 @@ import (
"golang.org/x/crypto/ed25519"
)
// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)
// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
var err error
@@ -34,10 +40,10 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
t := &TBSCertificate{
@@ -70,11 +76,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
if len(networks) == 0 {
+10 -4
View File
@@ -14,6 +14,12 @@ import (
"golang.org/x/crypto/ed25519"
)
// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)
// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
var err error
@@ -35,10 +41,10 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
t := &cert.TBSCertificate{
@@ -71,11 +77,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}
var pub, priv []byte
+32 -7
View File
@@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
} else {
// out-key is meaningless under PKCS#11 because the private key never
// leaves the HSM; reject it so we never silently accept or claim a
// stdout slot for it.
outKeySet := false
cf.set.Visit(func(f *flag.Flag) {
if f.Name == "out-key" {
outKeySet = true
}
})
if outKeySet {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
}
if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
return err
@@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
}
}
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-crt", *cf.outCertPath,
"out-qr", *cf.outQRPath,
); err != nil {
return err
}
var passphrase []byte
if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
@@ -261,14 +283,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
Curve: curve,
}
if !isP11 {
if !isP11 && !isStdio(*cf.outKeyPath) {
if _, err := os.Stat(*cf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
}
}
if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
if !isStdio(*cf.outCertPath) {
if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
}
}
var c cert.Certificate
@@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
}
err = os.WriteFile(*cf.outKeyPath, b, 0600)
err = writeOutput(*cf.outKeyPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
@@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while marshalling certificate: %s", err)
}
err = os.WriteFile(*cf.outCertPath, b, 0600)
err = writeOutput(*cf.outCertPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
@@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*cf.outQRPath, b, 0600)
err = writeOutput(*cf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@@ -332,6 +356,7 @@ func caSummary() string {
func caHelp(out io.Writer) {
cf := newCaFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n"))
out.Write([]byte(stdioHelpText))
cf.set.SetOutput(out)
cf.set.PrintDefaults()
}
+72 -7
View File
@@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -argon-iterations uint\n"+
" \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+
" -argon-memory uint\n"+
@@ -84,7 +85,7 @@ func Test_ca(t *testing.T) {
err: nil,
}
pwPromptOb := "Enter passphrase: "
pwPromptEB := "Enter passphrase: "
// required args
assertHelpError(t, ca(
@@ -168,8 +169,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, pwPromptEB, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
@@ -207,8 +208,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, pwPromptEB, eb.String())
// test when user fails to enter a password
os.Remove(keyF.Name())
@@ -217,8 +218,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
@@ -247,3 +248,67 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name())
}
func Test_ca_stdio(t *testing.T) {
nopw := &StubPasswordReader{}
keyF, err := os.CreateTemp("", "ca.key")
require.NoError(t, err)
os.Remove(keyF.Name())
defer os.Remove(keyF.Name())
crtF, err := os.CreateTemp("", "ca.crt")
require.NoError(t, err)
os.Remove(crtF.Name())
defer os.Remove(crtF.Name())
// out-crt on stdout, out-key on disk
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw))
assert.Empty(t, eb.String())
c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.True(t, c.IsCA())
assert.Equal(t, "test-ca", c.Name())
// out-key on stdout, out-crt on disk
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw))
assert.Empty(t, eb.String())
_, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
// dual stdout is rejected up front
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
require.EqualError(t,
ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw),
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
// an output conflict combined with -encrypt must error BEFORE prompting
// for a passphrase; pr would record any read attempt
tracker := &trackingPasswordReader{}
ob.Reset()
eb.Reset()
require.EqualError(t,
ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker),
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Zero(t, tracker.calls, "passphrase prompt should not have been called")
}
type trackingPasswordReader struct {
calls int
}
func (pr *trackingPasswordReader) ReadPassword() ([]byte, error) {
pr.calls++
return []byte(""), nil
}
+13 -2
View File
@@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
} else if *cf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
return err
@@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
}
}
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-pub", *cf.outPubPath,
); err != nil {
return err
}
if isP11 {
p11Client, err := pkclient.FromUrl(*cf.p11url)
if err != nil {
@@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while getting public key: %w", err)
}
} else {
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
}
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-pub: %s", err)
}
@@ -102,6 +112,7 @@ func keygenSummary() string {
func keygenHelp(out io.Writer) {
cf := newKeygenFlags()
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
_, _ = out.Write([]byte(stdioHelpText))
cf.set.SetOutput(out)
cf.set.PrintDefaults()
}
+41
View File
@@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -curve string\n"+
" \tECDH Curve (25519, P256) (default \"25519\")\n"+
" -out-key string\n"+
@@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) {
require.NoError(t, err)
assert.Len(t, lPub, 32)
}
func Test_keygen_stdio(t *testing.T) {
keyF, err := os.CreateTemp("", "test.key")
require.NoError(t, err)
os.Remove(keyF.Name())
defer os.Remove(keyF.Name())
pubF, err := os.CreateTemp("", "test.pub")
require.NoError(t, err)
os.Remove(pubF.Name())
defer os.Remove(pubF.Name())
// out-pub on stdout, out-key on disk
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb))
assert.Empty(t, eb.String())
lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Len(t, lPub, 32)
// out-key on stdout, out-pub on disk
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb))
assert.Empty(t, eb.String())
lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Len(t, lKey, 32)
// both on stdout is a conflict caught up front
ob.Reset()
eb.Reset()
require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb),
`-out-key and -out-pub both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
}
+3 -1
View File
@@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) {
}
password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
// Terminal echo is off while reading, so the user's Enter key does not
// produce a visible newline. Emit one on stderr to match the prompt.
fmt.Fprintln(os.Stderr)
return password, err
}
+23 -8
View File
@@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return err
}
rawCert, err := os.ReadFile(*pf.path)
var claims ioClaims
if err := reserveInputs(&claims, "path", *pf.path); err != nil {
return err
}
if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil {
return err
}
rawCert, err := readInput("path", *pf.path, &claims)
if err != nil {
return fmt.Errorf("unable to read cert; %s", err)
}
// When the QR is going to stdout, suppress the human-readable text/json
// output so the binary stream is not contaminated.
qrToStdout := isStdio(*pf.outQRPath)
var c cert.Certificate
var qrBytes []byte
part := 0
@@ -57,11 +69,13 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while unmarshaling cert: %s", err)
}
if *pf.json {
jsonCerts = append(jsonCerts, c)
} else {
_, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n"))
if !qrToStdout {
if *pf.json {
jsonCerts = append(jsonCerts, c)
} else {
_, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n"))
}
}
if *pf.outQRPath != "" {
@@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++
}
if *pf.json {
if *pf.json && !qrToStdout {
b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b)
_, _ = out.Write([]byte("\n"))
@@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*pf.outQRPath, b, 0600)
err = writeOutput(*pf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@@ -107,6 +121,7 @@ func printSummary() string {
func printHelp(out io.Writer) {
pf := newPrintFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n"))
out.Write([]byte(stdioHelpText))
pf.set.SetOutput(out)
pf.set.PrintDefaults()
}
+39
View File
@@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -json\n"+
" \tOptional: outputs certificates in json format\n"+
" -out-qr string\n"+
@@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) {
ob.String(),
)
assert.Empty(t, eb.String())
// read cert from stdin
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-"}, ob, eb)
require.NoError(t, err)
assert.Equal(
t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
assert.Empty(t, eb.String())
// -out-qr - sends only the PNG to stdout, suppressing the cert dump
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
stdout := ob.Bytes()
require.NotEmpty(t, stdout)
// PNG magic, no PEM/JSON noise prepended
assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8])
assert.NotContains(t, string(stdout), "NebulaCertificate")
assert.NotContains(t, string(stdout), `"details"`)
// json + out-qr - still suppresses json
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4])
assert.NotContains(t, ob.String(), `"details"`)
}
// NewTestCaCert will generate a CA cert
+42 -20
View File
@@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key")
}
if isP11 && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
var v4Networks []netip.Prefix
var v6Networks []netip.Prefix
@@ -102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
var claims ioClaims
if err := reserveInputs(&claims,
"ca-key", *sf.caKeyPath,
"ca-crt", *sf.caCertPath,
"in-pub", *sf.inPubPath,
); err != nil {
return err
}
if err := reserveOutputs(&claims,
"out-key", *sf.outKeyPath,
"out-crt", *sf.outCertPath,
"out-qr", *sf.outQRPath,
); err != nil {
return err
}
var curve cert.Curve
var caKey []byte
if !isP11 {
var rawCAKey []byte
rawCAKey, err := os.ReadFile(*sf.caKeyPath)
rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca-key: %s", err)
}
@@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if len(passphrase) == 0 {
// ask for a passphrase until we get one
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if errors.Is(err, ErrNoTerminal) {
@@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
}
rawCACert, err := os.ReadFile(*sf.caCertPath)
rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca-crt: %s", err)
}
@@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if *sf.inPubPath != "" {
var pubCurve cert.Curve
rawPub, err := os.ReadFile(*sf.inPubPath)
rawPub, err := readInput("in-pub", *sf.inPubPath, &claims)
if err != nil {
return fmt.Errorf("error while reading in-pub: %s", err)
}
@@ -266,16 +291,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve)
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
if _, err := os.Stat(*sf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
if !isStdio(*sf.outCertPath) {
if _, err := os.Stat(*sf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
}
}
var crts []cert.Certificate
@@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
if !isP11 && *sf.inPubPath == "" {
if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
if !isStdio(*sf.outKeyPath) {
if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
}
}
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
@@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
b = append(b, sb...)
}
err = os.WriteFile(*sf.outCertPath, b, 0600)
err = writeOutput(*sf.outCertPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
@@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*sf.outQRPath, b, 0600)
err = writeOutput(*sf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@@ -440,6 +461,7 @@ func signSummary() string {
func signHelp(out io.Writer) {
sf := newSignFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n"))
out.Write([]byte(stdioHelpText))
sf.set.SetOutput(out)
sf.set.PrintDefaults()
}
+112 -8
View File
@@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca-crt string\n"+
" \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+
" -ca-key string\n"+
@@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) {
// test with the proper password
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
ob.Reset()
eb.Reset()
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
@@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) {
testpw.password = []byte("invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the wrong password in environment
ob.Reset()
@@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String())
// test an error condition
ob.Reset()
@@ -425,6 +429,106 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
}
func Test_signCert_stdio(t *testing.T) {
nopw := &StubPasswordReader{
password: []byte(""),
err: nil,
}
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)
ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
rawCACrt, _ := ca.MarshalPEM()
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
require.NoError(t, err)
defer os.Remove(caCrtF.Name())
caCrtF.Write(rawCACrt)
caKeyF, err := os.CreateTemp("", "sign-cert.key")
require.NoError(t, err)
defer os.Remove(caKeyF.Name())
caKeyF.Write(rawCAKey)
keyF, err := os.CreateTemp("", "sign.key")
require.NoError(t, err)
os.Remove(keyF.Name())
defer os.Remove(keyF.Name())
// ca-key on stdin, cert to stdout
withStdin(t, bytes.NewReader(rawCAKey))
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, "stdin-test", lCrt.Name())
assert.True(t, lCrt.CheckSignature(caPub))
// two flags reading from stdin should error before any read attempt;
// otherwise an interactive shell would hang on io.ReadAll
stdinIn := bytes.NewReader(rawCAKey)
withStdin(t, stdinIn)
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw),
`-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front")
// two flags writing to stdout should error before any output is written
// AND before stdin is consumed
stdinR := bytes.NewReader(rawCAKey)
withStdin(t, stdinR)
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw),
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
// stdin should be untouched because the conflict was caught up front
assert.Equal(t, len(rawCAKey), stdinR.Len())
// out-key on stdout, cert on disk
keyF2, err := os.CreateTemp("", "sign.key")
require.NoError(t, err)
os.Remove(keyF2.Name())
defer os.Remove(keyF2.Name())
crtF, err := os.CreateTemp("", "sign.crt")
require.NoError(t, err)
os.Remove(crtF.Name())
defer os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
_, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
// in-pub on stdin (caller already has a keypair, only the cert is generated)
inPub, _ := x25519Keypair()
rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)
withStdin(t, bytes.NewReader(rawInPub))
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, "in-pub-test", stdinCrt.Name())
assert.Equal(t, inPub, stdinCrt.PublicKey())
}
+117
View File
@@ -0,0 +1,117 @@
package main
import (
"fmt"
"io"
"os"
)
// stdioPath is the special path value that selects stdin (for inputs) or
// stdout (for outputs) instead of a file on disk.
const stdioPath = "-"
// stdioHelpText is rendered just under the Usage line of each subcommand
// help so the - convention is documented once instead of on every flag.
const stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"
// stdinReader is the source used when an input flag is set to "-".
// It is a package level var so tests can swap in a deterministic reader.
// Tests that mutate stdinReader cannot run with t.Parallel().
var stdinReader io.Reader = os.Stdin
// ioClaims tracks which flags have claimed stdin and stdout during a single
// command invocation so we can refuse a second flag asking for the same
// stream.
type ioClaims struct {
in string
out string
}
func (c *ioClaims) claimIn(flagName string) error {
if c.in != "" && c.in != flagName {
return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", c.in, flagName, stdioPath)
}
c.in = flagName
return nil
}
func (c *ioClaims) claimOut(flagName string) error {
if c.out != "" && c.out != flagName {
return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", c.out, flagName, stdioPath)
}
c.out = flagName
return nil
}
// reserveInputs walks alternating (flagName, path) pairs and claims stdin
// for any path equal to stdioPath. It must be called before any input is
// read so a conflict can be reported immediately instead of blocking on
// io.ReadAll while waiting for input that will never arrive.
func reserveInputs(claims *ioClaims, pairs ...string) error {
return reserveStdio(claims, "reserveInputs", (*ioClaims).claimIn, pairs)
}
// reserveOutputs walks alternating (flagName, path) pairs and claims stdout
// for any path equal to stdioPath. It must be called before any output is
// written so a conflict cannot leave one stream half written before the
// second flag fails.
func reserveOutputs(claims *ioClaims, pairs ...string) error {
return reserveStdio(claims, "reserveOutputs", (*ioClaims).claimOut, pairs)
}
func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error {
if len(pairs)%2 != 0 {
panic(who + " requires alternating name, path pairs")
}
for i := 0; i < len(pairs); i += 2 {
name, path := pairs[i], pairs[i+1]
if path != stdioPath {
continue
}
if err := claim(claims, name); err != nil {
return err
}
}
return nil
}
// readInput returns the bytes referenced by path, reading from stdin when
// path is stdioPath.
func readInput(flagName, path string, claims *ioClaims) ([]byte, error) {
if path == stdioPath {
if err := claims.claimIn(flagName); err != nil {
return nil, err
}
return io.ReadAll(stdinReader)
}
return os.ReadFile(path)
}
// openInput returns a reader for path. When path is stdioPath the returned
// reader wraps stdin and Close is a no-op.
func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) {
if path == stdioPath {
if err := claims.claimIn(flagName); err != nil {
return nil, err
}
return io.NopCloser(stdinReader), nil
}
return os.Open(path)
}
// writeOutput writes data to path, or to stdout when path is stdioPath. perm
// is only used for file output. The caller must have already claimed stdout
// via reserveOutputs before invoking with stdioPath.
func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error {
if path == stdioPath {
_, err := stdout.Write(data)
return err
}
return os.WriteFile(path, data, perm)
}
// isStdio reports whether path is the stdio sentinel and so should skip
// existence checks like "refuse to overwrite".
func isStdio(path string) bool {
return path == stdioPath
}
+167
View File
@@ -0,0 +1,167 @@
package main
import (
"bytes"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// withStdin temporarily replaces stdinReader for the duration of t.
func withStdin(t *testing.T, r io.Reader) {
t.Helper()
prev := stdinReader
stdinReader = r
t.Cleanup(func() { stdinReader = prev })
}
func Test_readInput_stdin(t *testing.T) {
withStdin(t, bytes.NewBufferString("hello"))
var claims ioClaims
got, err := readInput("path", "-", &claims)
require.NoError(t, err)
assert.Equal(t, []byte("hello"), got)
assert.Equal(t, "path", claims.in)
}
func Test_readInput_file(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
require.NoError(t, os.WriteFile(p, []byte("file"), 0600))
var claims ioClaims
got, err := readInput("path", p, &claims)
require.NoError(t, err)
assert.Equal(t, []byte("file"), got)
assert.Empty(t, claims.in)
}
func Test_readInput_doubleStdinErrors(t *testing.T) {
withStdin(t, bytes.NewBufferString("hello"))
var claims ioClaims
_, err := readInput("ca-key", "-", &claims)
require.NoError(t, err)
_, err = readInput("ca-crt", "-", &claims)
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
}
func Test_openInput_stdin(t *testing.T) {
withStdin(t, bytes.NewBufferString("hi"))
var claims ioClaims
r, err := openInput("ca", "-", &claims)
require.NoError(t, err)
defer r.Close()
b, err := io.ReadAll(r)
require.NoError(t, err)
assert.Equal(t, []byte("hi"), b)
}
func Test_openInput_doubleStdinErrors(t *testing.T) {
withStdin(t, bytes.NewBufferString("hi"))
var claims ioClaims
r, err := openInput("ca", "-", &claims)
require.NoError(t, err)
r.Close()
_, err = openInput("crt", "-", &claims)
require.EqualError(t, err, `-ca and -crt both set to "-", only one input may read from stdin`)
}
func Test_writeOutput_stdout(t *testing.T) {
out := &bytes.Buffer{}
err := writeOutput("-", []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Equal(t, "payload", out.String())
}
func Test_writeOutput_file(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
out := &bytes.Buffer{}
err := writeOutput(p, []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Empty(t, out.String())
got, err := os.ReadFile(p)
require.NoError(t, err)
assert.Equal(t, []byte("payload"), got)
}
func Test_reserveOutputs_noConflict(t *testing.T) {
var claims ioClaims
require.NoError(t, reserveOutputs(&claims,
"out-key", "/tmp/key",
"out-crt", "-",
"out-qr", "",
))
assert.Equal(t, "out-crt", claims.out)
}
func Test_reserveOutputs_conflict(t *testing.T) {
var claims ioClaims
err := reserveOutputs(&claims,
"out-key", "-",
"out-crt", "-",
)
require.EqualError(t, err, `-out-key and -out-crt both set to "-", only one output may write to stdout`)
}
func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
}()
var claims ioClaims
_ = reserveOutputs(&claims, "out-key")
}
func Test_reserveInputs_noConflict(t *testing.T) {
var claims ioClaims
require.NoError(t, reserveInputs(&claims,
"ca-key", "/tmp/ca.key",
"ca-crt", "-",
"in-pub", "",
))
assert.Equal(t, "ca-crt", claims.in)
}
func Test_reserveInputs_conflict(t *testing.T) {
var claims ioClaims
err := reserveInputs(&claims,
"ca-key", "-",
"ca-crt", "-",
)
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
}
func Test_claimIn_idempotent(t *testing.T) {
// pre-claim then a lazy re-claim of the same flag should be a no-op
var claims ioClaims
require.NoError(t, claims.claimIn("ca-key"))
require.NoError(t, claims.claimIn("ca-key"))
assert.Equal(t, "ca-key", claims.in)
}
func Test_claimOut_idempotent(t *testing.T) {
var claims ioClaims
require.NoError(t, claims.claimOut("out-crt"))
require.NoError(t, claims.claimOut("out-crt"))
assert.Equal(t, "out-crt", claims.out)
}
func Test_isStdio(t *testing.T) {
assert.True(t, isStdio("-"))
assert.False(t, isStdio(""))
assert.False(t, isStdio("./-"))
assert.False(t, isStdio("foo"))
}
+13 -4
View File
@@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return err
}
caFile, err := os.Open(*vf.caPath)
var claims ioClaims
if err := reserveInputs(&claims,
"ca", *vf.caPath,
"crt", *vf.certPath,
); err != nil {
return err
}
caReader, err := openInput("ca", *vf.caPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca: %w", err)
}
defer caFile.Close()
defer caReader.Close()
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
if err != nil && !errors.Is(err, cert.ErrExpired) {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
rawCert, err := os.ReadFile(*vf.certPath)
rawCert, err := readInput("crt", *vf.certPath, &claims)
if err != nil {
return fmt.Errorf("unable to read crt: %w", err)
}
@@ -85,6 +93,7 @@ func verifySummary() string {
func verifyHelp(out io.Writer) {
vf := newVerifyFlags()
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
_, _ = out.Write([]byte(stdioHelpText))
vf.set.SetOutput(out)
vf.set.PrintDefaults()
}
+44
View File
@@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca string\n"+
" \tRequired: path to a file containing one or more ca certificates\n"+
" -crt string\n"+
@@ -122,3 +123,46 @@ func Test_verify(t *testing.T) {
assert.Empty(t, eb.String())
require.NoError(t, err)
}
func Test_verify_stdio(t *testing.T) {
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
caPEM, _ := ca.MarshalPEM()
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
crtPEM, _ := crt.MarshalPEM()
caFile, err := os.CreateTemp("", "verify-ca")
require.NoError(t, err)
defer os.Remove(caFile.Name())
caFile.Write(caPEM)
// crt on stdin, ca on disk
withStdin(t, bytes.NewReader(crtPEM))
require.NoError(t, verify([]string{"-ca", caFile.Name(), "-crt", "-"}, ob, eb))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// ca on stdin, crt on disk
certFile, err := os.CreateTemp("", "verify-cert")
require.NoError(t, err)
defer os.Remove(certFile.Name())
certFile.Write(crtPEM)
withStdin(t, bytes.NewReader(caPEM))
ob.Reset()
eb.Reset()
require.NoError(t, verify([]string{"-ca", "-", "-crt", certFile.Name()}, ob, eb))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// both flags on stdin should error
withStdin(t, bytes.NewReader(caPEM))
ob.Reset()
eb.Reset()
require.EqualError(t, verify([]string{"-ca", "-", "-crt", "-"}, ob, eb),
`-ca and -crt both set to "-", only one input may read from stdin`)
}
+6 -3
View File
@@ -61,9 +61,12 @@ func main() {
}
if *configPath == "" {
fmt.Println("-config flag must be set")
flag.Usage()
os.Exit(1)
p, err := config.DefaultPath()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
*configPath = p
}
c := config.NewC(l)
+2 -15
View File
@@ -3,8 +3,6 @@ package main
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/kardianos/service"
"github.com/slackhq/nebula"
@@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error {
return nil
}
func fileExists(filename string) bool {
_, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return true
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
if *configPath == "" {
ex, err := os.Executable()
p, err := config.DefaultPath()
if err != nil {
return err
}
*configPath = filepath.Dir(ex) + "/config.yaml"
if !fileExists(*configPath) {
*configPath = filepath.Dir(ex) + "/config.yml"
}
*configPath = p
}
svcConfig := &service.Config{
+6 -3
View File
@@ -50,9 +50,12 @@ func main() {
}
if *configPath == "" {
fmt.Println("-config flag must be set")
flag.Usage()
os.Exit(1)
p, err := config.DefaultPath()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
*configPath = p
}
l := logging.NewLogger(os.Stdout)
+29
View File
@@ -0,0 +1,29 @@
package config
import (
"fmt"
"os"
"path/filepath"
)
// DefaultPath returns a path to a config file alongside the running executable, preferring config.yaml over config.yml.
// If neither file exists an error is returned that names both paths checked.
func DefaultPath() (string, error) {
ex, err := os.Executable()
if err != nil {
return "", err
}
return defaultPathInDir(filepath.Dir(ex))
}
func defaultPathInDir(dir string) (string, error) {
yamlPath := filepath.Join(dir, "config.yaml")
if _, err := os.Stat(yamlPath); err == nil {
return yamlPath, nil
}
ymlPath := filepath.Join(dir, "config.yml")
if _, err := os.Stat(ymlPath); err == nil {
return ymlPath, nil
}
return "", fmt.Errorf("no default config found at %s or %s", yamlPath, ymlPath)
}
+67
View File
@@ -0,0 +1,67 @@
package config
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDefaultPathInDir(t *testing.T) {
t.Run("prefers config.yaml when both exist", func(t *testing.T) {
dir := t.TempDir()
want := filepath.Join(dir, "config.yaml")
other := filepath.Join(dir, "config.yml")
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
require.NoError(t, os.WriteFile(other, []byte("a: 2"), 0644))
got, err := defaultPathInDir(dir)
require.NoError(t, err)
assert.Equal(t, want, got)
})
t.Run("returns config.yaml when only it exists", func(t *testing.T) {
dir := t.TempDir()
want := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
got, err := defaultPathInDir(dir)
require.NoError(t, err)
assert.Equal(t, want, got)
})
t.Run("falls back to config.yml when only it exists", func(t *testing.T) {
dir := t.TempDir()
want := filepath.Join(dir, "config.yml")
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
got, err := defaultPathInDir(dir)
require.NoError(t, err)
assert.Equal(t, want, got)
})
t.Run("errors when neither exists and names both paths", func(t *testing.T) {
dir := t.TempDir()
got, err := defaultPathInDir(dir)
assert.Empty(t, got)
require.Error(t, err)
assert.Contains(t, err.Error(), filepath.Join(dir, "config.yaml"))
assert.Contains(t, err.Error(), filepath.Join(dir, "config.yml"))
})
}
func TestDefaultPath(t *testing.T) {
got, err := DefaultPath()
if err != nil {
ex, exErr := os.Executable()
require.NoError(t, exErr)
assert.Contains(t, err.Error(), filepath.Dir(ex))
return
}
ex, err := os.Executable()
require.NoError(t, err)
assert.Equal(t, filepath.Dir(ex), filepath.Dir(got))
assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(got))
}
+12 -42
View File
@@ -11,7 +11,6 @@ import (
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
@@ -45,19 +44,16 @@ type connectionManager struct {
inactivityTimeout atomic.Int64
dropInactive atomic.Bool
metricsTxPunchy metrics.Counter
l *slog.Logger
}
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{
hostMap: hm,
l: l,
punchy: p,
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
hostMap: hm,
l: l,
punchy: p,
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
}
cm.reload(c, true)
@@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if !outTraffic {
// Send a punch packet to keep the NAT state alive
cm.sendPunch(hostinfo)
cm.punchy.SendPunch(hostinfo)
}
return decision, hostinfo, primary
@@ -400,17 +396,16 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo)
cm.punchy.SendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
return doNothing, nil, nil
}
if cm.punchy.GetTargetEverything() {
// This is similar to the old punchy behavior with a slight optimization.
// We aren't receiving traffic but we are sending it, punch on all known
// ips in case we need to re-prime NAT state
cm.sendPunch(hostinfo)
}
// We aren't receiving traffic but we are sending it. The outbound
// traffic itself refreshes the primary remote's NAT state; this
// fans out to non-primary remotes, but only if target_all_remotes
// is configured.
cm.punchy.SendPunchToAll(hostinfo)
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
@@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
}
}
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !cm.punchy.GetPunch() {
// Punching is disabled
return
}
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
// would lose the ability to notify us and punchy.respond would become unreliable.
return
}
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
+4 -4
View File
@@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Create manager
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
@@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Create manager
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
@@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
@@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
// Create manager
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
ifce.connectionManager = nc
+5 -4
View File
@@ -7,13 +7,14 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/noiseutil"
)
const ReplayWindow = 1024
type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
eKey noiseutil.CipherState
dKey noiseutil.CipherState
myCert cert.Certificate
peerCert *cert.CachedCertificate
initiator bool
@@ -31,8 +32,8 @@ func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
myCert: r.MyCert,
initiator: r.Initiator,
peerCert: r.RemoteCert,
eKey: NewNebulaCipherState(r.EKey),
dKey: NewNebulaCipherState(r.DKey),
eKey: noiseutil.NewCipherState(r.EKey, r.Cipher),
dKey: noiseutil.NewCipherState(r.DKey, r.Cipher),
window: NewBits(ReplayWindow),
}
ci.messageCounter.Add(r.MessageIndex)
+92 -26
View File
@@ -11,19 +11,21 @@ import (
"sync"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/slackhq/nebula/config"
)
type dnsServer struct {
sync.RWMutex
l *slog.Logger
ctx context.Context
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
l *slog.Logger
ctx context.Context
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
pki *PKI
// selfHost is the cached FQDN we last seeded for ourselves
selfHost string
mux *dns.ServeMux
@@ -55,14 +57,14 @@ type dnsServer struct {
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
// watcher that tears the listener down on nebula shutdown. The returned
// pointer is always non-nil, even on error.
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) {
ds := &dnsServer{
l: l,
ctx: ctx,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
l: l,
ctx: ctx,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
pki: pki,
}
ds.mux = dns.NewServeMux()
ds.mux.HandleFunc(".", ds.handleDnsRequest)
@@ -76,6 +78,7 @@ func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState,
if err := ds.reload(c, true); err != nil {
return ds, err
}
ds.seedSelf()
return ds, nil
}
@@ -113,7 +116,7 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
d.Stop()
}
// Drop any records that accumulated while enabled; a later re-enable
// will repopulate from fresh handshakes.
// will repopulate from fresh handshakes and a fresh seedSelf.
d.clearRecords()
return nil
}
@@ -121,17 +124,14 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
if running == nil {
// Was disabled (or never started); bring it up now.
go d.Start()
return nil
} else if !sameAddr {
d.shutdownServer(running, runningStarted, "reload")
// Old Start goroutine has now exited; bring up a fresh listener on the new address.
go d.Start()
}
if sameAddr {
return nil
}
d.shutdownServer(running, runningStarted, "reload")
// Old Start goroutine has now exited; bring up a fresh listener on the
// new address.
go d.Start()
// Refresh the self entry every enabled reload so cert renewals that change our name or VPN addresses are picked up.
d.seedSelf()
return nil
}
@@ -249,6 +249,20 @@ func (d *dnsServer) QueryCert(data string) string {
return ""
}
// The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves.
// Answer self lookups straight from the local cert state.
if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) {
c := cs.GetDefaultCertificate()
if c == nil {
return ""
}
b, err := c.MarshalJSON()
if err != nil {
return ""
}
return string(b)
}
hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil {
return ""
@@ -266,12 +280,60 @@ func (d *dnsServer) QueryCert(data string) string {
return string(b)
}
// clearRecords drops all DNS records.
// clearRecords drops all DNS records, including the self entry.
func (d *dnsServer) clearRecords() {
d.Lock()
defer d.Unlock()
clear(d.dnsMap4)
clear(d.dnsMap6)
d.selfHost = ""
}
// seedSelf inserts (or refreshes) a record for our own cert name pointing at our VPN addresses,
// so a single-lighthouse network can resolve the lighthouse's own hostname without the two-process workaround.
func (d *dnsServer) seedSelf() {
if !d.enabled.Load() {
return
}
cs := d.certState()
if cs == nil {
return
}
c := cs.GetDefaultCertificate()
if c == nil {
return
}
newHost := strings.ToLower(c.Name()) + "."
d.Lock()
defer d.Unlock()
if d.selfHost != "" && d.selfHost != newHost {
delete(d.dnsMap4, d.selfHost)
delete(d.dnsMap6, d.selfHost)
}
d.selfHost = newHost
delete(d.dnsMap4, newHost)
delete(d.dnsMap6, newHost)
haveV4, haveV6 := false, false
for _, addr := range cs.myVpnAddrs {
if addr.Is4() && !haveV4 {
d.dnsMap4[newHost] = addr
haveV4 = true
} else if addr.Is6() && !haveV6 {
d.dnsMap6[newHost] = addr
haveV6 = true
}
if haveV4 && haveV6 {
break
}
}
}
func (d *dnsServer) certState() *CertState {
if d.pki == nil {
return nil
}
return d.pki.getCertState()
}
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
@@ -309,8 +371,12 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
return true
}
cs := d.certState()
if cs == nil || cs.myVpnAddrsTable == nil {
return false
}
//if we found it in this table, it's good
return d.myVpnAddrsTable.Contains(b)
return cs.myVpnAddrsTable.Contains(b)
}
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
+89
View File
@@ -9,7 +9,10 @@ import (
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -276,6 +279,92 @@ func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
}
}
// newTestPKI builds a minimal *PKI with a single v1 cert whose name and
// VPN addresses are caller-provided, suitable for exercising seedSelf and
// QueryCert self handling.
func newTestPKI(t *testing.T, name string, addrs []netip.Addr) *PKI {
t.Helper()
networks := make([]netip.Prefix, 0, len(addrs))
for _, a := range addrs {
bits := 32
if a.Is6() {
bits = 128
}
networks = append(networks, netip.PrefixFrom(a, bits))
}
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
c, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, nil, nil)
addrsTable := new(bart.Lite)
for _, a := range addrs {
addrsTable.Insert(netip.PrefixFrom(a, a.BitLen()))
}
cs := &CertState{
v2Cert: c,
initiatingVersion: cert.Version2,
myVpnAddrs: addrs,
myVpnAddrsTable: addrsTable,
}
pki := &PKI{}
pki.cs.Store(cs)
return pki
}
func TestDnsServer_seedSelf_addsOwnRecord(t *testing.T) {
ds, c := newTestDnsServer(t)
myV4 := netip.MustParseAddr("10.0.0.1")
myV6 := netip.MustParseAddr("fd00::1")
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4, myV6})
setDnsConfig(c, "127.0.0.1", "0", true, true)
require.NoError(t, ds.reload(c, true))
ds.seedSelf()
got4, exists := ds.Query(dns.TypeA, "lighthouse.")
assert.True(t, exists)
assert.Equal(t, myV4, got4)
got6, exists := ds.Query(dns.TypeAAAA, "lighthouse.")
assert.True(t, exists)
assert.Equal(t, myV6, got6)
}
func TestDnsServer_seedSelf_disabled_noOp(t *testing.T) {
ds, c := newTestDnsServer(t)
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
setDnsConfig(c, "127.0.0.1", "0", true, false)
require.NoError(t, ds.reload(c, true))
ds.seedSelf()
_, exists := ds.Query(dns.TypeA, "lighthouse.")
assert.False(t, exists)
}
func TestDnsServer_clearRecords_dropsSelfHost(t *testing.T) {
ds, c := newTestDnsServer(t)
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
setDnsConfig(c, "127.0.0.1", "0", true, true)
require.NoError(t, ds.reload(c, true))
ds.seedSelf()
require.NotEmpty(t, ds.selfHost)
ds.clearRecords()
assert.Empty(t, ds.selfHost)
_, exists := ds.Query(dns.TypeA, "lighthouse.")
assert.False(t, exists)
}
func TestDnsServer_QueryCert_returnsOwnCert(t *testing.T) {
ds, _ := newTestDnsServer(t)
myV4 := netip.MustParseAddr("10.0.0.1")
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4})
got := ds.QueryCert(myV4.String() + ".")
assert.NotEmpty(t, got, "TXT lookup of our own VPN address should return our cert")
other := netip.MustParseAddr("10.0.0.99")
assert.Empty(t, ds.QueryCert(other.String()+"."), "unknown peer IP should return nothing")
}
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
port := freeUDPPort(t)
ds, c := newTestDnsServer(t)
+3 -7
View File
@@ -18,14 +18,10 @@ import (
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
// before failing the assertion.
//
// IgnoreCurrent is necessary in the parallelized suite: other tests can
// leave goroutines mid-shutdown when this one runs (Stop is async, the
// wg.Wait() drain is not blocking on test return). We're checking that
// *this* test's setup tears down cleanly, not that the whole suite is
// idle at this moment. Intentionally NOT t.Parallel()'d for the same
// reason — concurrent test goroutines would always show up.
// Intentionally NOT t.Parallel()'d: concurrent tests would have their own
// goroutines running and trip the assertion.
func TestNoGoroutineLeaks(t *testing.T) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
defer goleak.VerifyNone(t)
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+125
View File
@@ -0,0 +1,125 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"net"
"strings"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
func TestSSHDLifecycle(t *testing.T) {
// TestSSHDLifecycle exercises the in-process sshd through several config reloads and a Control.Stop.
ca, _, caKey, _ := cert_test.NewTestCaCert(
cert.Version1, cert.Curve_CURVE25519,
time.Now(), time.Now().Add(10*time.Minute),
nil, nil, []string{},
)
hostKeyPEM := generateSSHHostKey(t)
clientSigner, clientAuthKey := generateSSHClientKey(t)
sshdAddr := allocLoopbackPort(t)
overrides := m{
"sshd": m{
"enabled": true,
"listen": sshdAddr,
"host_key": hostKeyPEM,
"authorized_users": []m{{
"user": "tester",
"keys": []string{clientAuthKey},
}},
},
}
control, _, _, _ := newSimpleServer(cert.Version1, ca, caKey, "sshd-test", "10.222.0.1/24", overrides)
control.Start()
t.Cleanup(func() { control.Stop() })
// sshd binds in a goroutine after Start returns; wait for it.
require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
"sshd never started listening")
for i := 1; i <= 3; i++ {
out := sshExecReload(t, sshdAddr, clientSigner)
assert.Contains(t, out, "Reloading config", "reload cycle %d", i)
require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
"sshd not listening after reload cycle %d", i)
}
control.Stop()
require.Eventually(t, func() bool { return !canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
"sshd still listening after Control.Stop")
}
func canDial(addr string) bool {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err != nil {
return false
}
_ = c.Close()
return true
}
// allocLoopbackPort grabs an unused TCP port on 127.0.0.1, closes it, and returns the address. There
// is a small race between releasing the port and the sshd reclaiming it; in practice the OS keeps the
// port available long enough for the test to bind it.
func allocLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func generateSSHHostKey(t *testing.T) string {
t.Helper()
_, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
block, err := ssh.MarshalPrivateKey(priv, "nebula-e2e-host")
require.NoError(t, err)
return string(pem.EncodeToMemory(block))
}
func generateSSHClientKey(t *testing.T) (ssh.Signer, string) {
t.Helper()
_, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
signer, err := ssh.NewSignerFromKey(priv)
require.NoError(t, err)
auth := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
return signer, auth
}
func sshExecReload(t *testing.T, addr string, signer ssh.Signer) string {
t.Helper()
cfg := &ssh.ClientConfig{
User: "tester",
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
client, err := ssh.Dial("tcp", addr, cfg)
require.NoError(t, err)
defer client.Close()
sess, err := client.NewSession()
require.NoError(t, err)
defer sess.Close()
// reload tears the channel down before sending exit-status, so Output returns an error on the
// channel close. The output buffer still has whatever the reload callback wrote before that.
out, _ := sess.Output("reload")
return string(out)
}
+30
View File
@@ -138,6 +138,14 @@ listen:
# max, net.core.rmem_max and net.core.wmem_max
#read_buffer: 10485760
#write_buffer: 10485760
# On Windows only
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to UDP at the listener port.
# WFP sits below Windows Defender Firewall, so this lets peer handshakes reach Nebula's outside socket regardless
# of WDF's inbound rules.
# Default true; set to false to leave WDF in charge of inbound decisions on the listener port. Not reloadable.
#windows_bypass_wdf: true
# By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection
# in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running
# on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes.
@@ -163,17 +171,21 @@ listen:
punchy:
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
# This setting is reloadable.
punch: true
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
# Default is false
# This setting is reloadable.
#respond: true
# delays a punch response for misbehaving NATs, default is 1 second.
# This setting is reloadable.
#delay: 1s
# set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect.
# This setting is reloadable.
#respond_delay: 5s
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
@@ -282,6 +294,24 @@ tun:
# metric: 100
# install: true
# On Windows only, sets the network category of the nebula interface. Without this, Windows often
# leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more
# restrictive than you usually want for an overlay between trusted peers. Valid values:
# private - treat the nebula network as a private/trusted network (default)
# public - treat it as a public/untrusted network
# domain - treat it as a domain-authenticated network
# unset - leave whatever Windows decided alone
# Not reloadable.
#network_category: private
# On Windows only
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to the nebula adapter LUID.
# WFP sits below Windows Defender Firewall, so this lets inbound traffic through regardless of WDF rules.
# Filters are auto-removed when the adapter goes away.
# See listen.windows_bypass_wdf for the matching control over inbound to nebula's outside UDP listener.
# Default true; set to false to leave WDF in charge of inbound decisions on the nebula interface. Not reloadable.
#windows_bypass_wdf: true
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
# in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false
+43 -26
View File
@@ -58,8 +58,9 @@ type Firewall struct {
routableNetworks *bart.Lite
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool
assignedNetworks []netip.Prefix
// unsafeNetworks is the list of unsafe networks issued to us in the certificate
unsafeNetworks []netip.Prefix
rules string
rulesVersion uint16
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
assignedNetworks = append(assignedNetworks, network)
}
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
unsafeNetworks := c.UnsafeNetworks()
for _, n := range unsafeNetworks {
routableNetworks.Insert(n)
hasUnsafeNetworks = true
}
return &Firewall{
@@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
l: l,
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
unsafeNetworks: unsafeNetworks,
l: l,
incomingMetrics: firewallMetrics{
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
@@ -1055,7 +1055,6 @@ func (r *rule) sanity() error {
}
func parsePort(s string) (int32, int32, error) {
var err error
const notAPort int32 = -2
if s == "any" {
return firewall.PortAny, firewall.PortAny, nil
@@ -1064,11 +1063,11 @@ func parsePort(s string) (int32, int32, error) {
return firewall.PortFragment, firewall.PortFragment, nil
}
if !strings.Contains(s, `-`) {
rPort, err := strconv.Atoi(s)
rPort, err := parsePortValue("", s)
if err != nil {
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
return notAPort, notAPort, err
}
return int32(rPort), int32(rPort), nil
return rPort, rPort, nil
}
sPorts := strings.SplitN(s, `-`, 2)
@@ -1079,22 +1078,40 @@ func parsePort(s string) (int32, int32, error) {
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
}
rStartPort, err := strconv.Atoi(sPorts[0])
startPort, err := parsePortValue("beginning range ", sPorts[0])
if err != nil {
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
return notAPort, notAPort, err
}
rEndPort, err := strconv.Atoi(sPorts[1])
endPort, err := parsePortValue("ending range ", sPorts[1])
if err != nil {
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
return notAPort, notAPort, err
}
startPort := int32(rStartPort)
endPort := int32(rEndPort)
if startPort == firewall.PortAny {
endPort = firewall.PortAny
}
return startPort, endPort, nil
}
// parsePortValue accepts a base-10 decimal in [0, 65535] and returns it
// widened to int32. Using strconv.ParseUint with bitSize 16 rejects
// negative input, out-of-range input (>65535), and any non-decimal byte
// by construction, so the int32 widening that follows is provably safe
// and cannot collide with firewall.PortAny (0) or firewall.PortFragment
// (-1) via integer truncation.
//
// prefix is prepended to both error messages so callers can disambiguate
// the single-port path (prefix="") from the range bounds (prefix="beginning
// range " / "ending range "), preserving the historical error strings.
func parsePortValue(prefix, s string) (int32, error) {
n, err := strconv.ParseUint(s, 10, 16)
if err == nil {
return int32(n), nil
}
if errors.Is(err, strconv.ErrRange) {
return 0, fmt.Errorf("%sout of range [0,65535]; `%s`", prefix, s)
}
return 0, fmt.Errorf("%swas not a number; `%s`", prefix, s)
}
+69
View File
@@ -1029,6 +1029,75 @@ func Test_parsePort(t *testing.T) {
require.NoError(t, err)
}
// Test_parsePort_invalid covers inputs that must error. The named bug is
// that int32(strconv.Atoi("4294967296")) truncates to 0 == firewall.PortAny,
// silently turning a typo into a match-all-ports rule; the rest are
// representative syntax/range probes.
func Test_parsePort_invalid(t *testing.T) {
tests := []struct {
name string
input string
wantErrContains string
}{
// Numeric overflow (the named bug + boundary).
{"named bug: 2^32 truncates to PortAny", "4294967296", "out of range"},
{"just above max real port", "65536", "out of range"},
// Negatives route through the range branch and hit the empty-half
// guard; included as defense in depth so a future refactor cannot
// accidentally reach the int32 cast.
{"negative", "-1", "could not be parsed"},
// Syntax probes.
{"NUL between digits", "4\x002", "was not a number"},
{"hex notation", "0x10", "was not a number"},
{"scientific notation", "1e3", "was not a number"},
{"leading whitespace", " 42", "was not a number"},
{"fullwidth digits", "42", "was not a number"},
// Range branch.
{"range upper out of range", "1-65536", "ending range out of range"},
{"range lower out of range", "65536-65537", "beginning range out of range"},
{"range with negative upper", "1--1", "ending range was not a number"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, _, err := parsePort(tc.input)
require.Error(t, err, "input %q must error", tc.input)
require.ErrorContains(t, err, tc.wantErrContains)
})
}
}
// Test_parsePort_valid_boundaries locks in success cases at 0, 1, and 65535
// so a future refactor cannot regress the boundaries.
func Test_parsePort_valid_boundaries(t *testing.T) {
tests := []struct {
name string
input string
wantStart int32
wantEnd int32
}{
{"zero is PortAny", "0", 0, 0},
{"min real port", "1", 1, 1},
{"max real port", "65535", 65535, 65535},
{"range zero to max forces end to zero", "0-65535", 0, 0},
{"range max to max", "65535-65535", 65535, 65535},
{"range one to max", "1-65535", 1, 65535},
{"range with whitespace inside", " 1 - 2 ", 1, 2},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
s, e, err := parsePort(tc.input)
require.NoError(t, err)
assert.Equal(t, tc.wantStart, s, "start port")
assert.Equal(t, tc.wantEnd, e, "end port")
})
}
}
func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger()
// Test a bad rule definition
+5 -5
View File
@@ -9,7 +9,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.26.0
github.com/gaissmai/bart v0.27.1
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
@@ -24,12 +24,12 @@ require (
github.com/vishvananda/netlink v1.3.1
go.uber.org/goleak v1.3.0
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.50.0
golang.org/x/crypto v0.51.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.52.0
golang.org/x/net v0.54.0
golang.org/x/sync v0.20.0
golang.org/x/sys v0.43.0
golang.org/x/term v0.42.0
golang.org/x/sys v0.44.0
golang.org/x/term v0.43.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.6.1
+10 -10
View File
@@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.27.1 h1:FysPzqETMJa8q9rNkLW5peT1hq25nLOz8ksHbSVoiAk=
github.com/gaissmai/bart v0.27.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+2
View File
@@ -33,6 +33,7 @@ type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
type Result struct {
EKey *noise.CipherState
DKey *noise.CipherState
Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in
MyCert cert.Certificate
RemoteCert *cert.CachedCertificate
RemoteIndex uint32
@@ -114,6 +115,7 @@ func NewMachine(
myVersion: version,
result: &Result{
Initiator: initiator,
Cipher: cred.cipherSuite,
},
multiport: multiport,
+2 -4
View File
@@ -87,6 +87,7 @@ type HandshakeHostInfo struct {
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
lastRelays []netip.Addr // Relays we attempted to use during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo
@@ -221,7 +222,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
fields := []any{
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
"initiatorIndex", hh.hostinfo.localIndexId,
"remoteIndex", hh.hostinfo.remoteIndexId,
"durationNs", time.Since(hh.startTime).Nanoseconds(),
}
// hh.machine can be nil here if buildStage0Packet never succeeded
@@ -352,7 +352,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
)
}
hm.f.relayManager.StartRelays(hm.f, vpnIp, hostinfo, stage0)
hm.f.relayManager.StartRelays(hm.f, vpnIp, hh, stage0)
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered {
@@ -494,7 +494,6 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
@@ -517,7 +516,6 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
-1
View File
@@ -409,7 +409,6 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
"error", err,
"udpAddr", remote,
"counter", c,
"attemptedCounter", c,
)
return
}
+44 -25
View File
@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
@@ -14,6 +15,7 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -389,13 +391,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
}
func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
cs := f.pki.getCertState()
curCert := cs.getCertificate(cert.Version2)
if curCert == nil {
curCert = cs.getCertificate(cert.Version1)
}
// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
// Check to see if that set has changed, and if so, rebuild the firewall.
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)
if !c.HasChanged("firewall") && !certUnsafeChanged {
f.l.Debug("No firewall config change detected")
return
}
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
fw, err := NewFirewallFromConfig(f.l, cs, c)
if err != nil {
f.l.Error("Error while creating firewall during reload", "error", err)
return
@@ -507,33 +518,41 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
emit := func() {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
certInitiatingVersion.Update(int64(defaultCrt.Version()))
if f.udpRaw != nil {
if rawStats == nil {
rawStats = udp.NewRawStatsEmitter(f.udpRaw)
}
rawStats()
}
// Report the max certificate version we are capable of using
if certState.v2Cert != nil {
certMaxVersion.Update(int64(certState.v2Cert.Version()))
} else {
certMaxVersion.Update(int64(certState.v1Cert.Version()))
}
}
// Prime gauges so a Prometheus scrape that lands before the first tick
// sees real values instead of the zero defaults (issue #907).
emit()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
certInitiatingVersion.Update(int64(defaultCrt.Version()))
if f.udpRaw != nil {
if rawStats == nil {
rawStats = udp.NewRawStatsEmitter(f.udpRaw)
}
rawStats()
}
// Report the max certificate version we are capable of using
if certState.v2Cert != nil {
certMaxVersion.Update(int64(certState.v2Cert.Version()))
} else {
certMaxVersion.Update(int64(certState.v1Cert.Version()))
}
emit()
}
}
}
+73
View File
@@ -0,0 +1,73 @@
//go:build linux || darwin
package nebula
import (
"context"
"net/netip"
"testing"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that
// landed before the first ticker fire used to read 0 for the cert gauges.
// emitStats now primes the gauges before entering the ticker loop. We assert
// the gauge is zero before the first call and non-zero after.
func Test_emitStats_primesGauges(t *testing.T) {
defer metrics.DefaultRegistry.UnregisterAll()
l := test.NewLogger()
hostMap := newHostMap(l)
preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
hostMap.preferredRanges.Store(&preferredRanges)
notAfter := time.Now().Add(time.Hour)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter},
v1Credential: nil,
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
// On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to
// *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn,
// returns an error, and the emitter falls through to a no-op.
writers: []udp.Conn{&udp.StdConn{}},
}
ifce.pki.cs.Store(cs)
ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs")
// Pre-cancel the context so emitStats returns after priming the gauges
// without ever reading from ticker.C. The one hour interval is just a
// belt-and-suspenders, the test does not expect the ticker to fire.
ctx, cancel := context.WithCancel(context.Background())
cancel()
ifce.emitStats(ctx, time.Hour)
ttl := ttlGauge.Value()
assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick")
assert.LessOrEqual(t, ttl, int64(3600))
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value())
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value())
}
+120
View File
@@ -0,0 +1,120 @@
package nebula
import (
"net/netip"
"testing"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall
// rebuilds the firewall when only the certificate's UnsafeNetworks have changed,
// even if the firewall section of the YAML has not.
func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) {
l := test.NewLogger()
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
// dummyCert avoids dragging the real signing pipeline into a unit test.
c1 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: initialUnsafe,
}
pki := &PKI{}
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
rawYAML := `firewall:
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
`
cfg := config.NewC(l)
require.NoError(t, cfg.LoadString(rawYAML))
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
require.NoError(t, err)
require.Equal(t, initialUnsafe, fw.unsafeNetworks)
f := &Interface{
pki: pki,
firewall: fw,
l: l,
}
// Swap the cert with a different UnsafeNetworks set.
newUnsafe := []netip.Prefix{
netip.MustParsePrefix("198.51.100.0/24"),
netip.MustParsePrefix("203.0.113.0/24"),
}
c2 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: newUnsafe,
}
pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2})
// Reload with the same YAML so HasChanged("firewall") reports false.
require.NoError(t, cfg.ReloadConfigString(rawYAML))
require.False(t, cfg.HasChanged("firewall"))
f.reloadFirewall(cfg)
assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced")
assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks)
assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5")))
}
// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when
// neither the firewall config nor the cert's UnsafeNetworks have changed.
func TestReloadFirewall_NoChange(t *testing.T) {
l := test.NewLogger()
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
c1 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: unsafe,
}
pki := &PKI{}
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
rawYAML := `firewall:
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
`
cfg := config.NewC(l)
require.NoError(t, cfg.LoadString(rawYAML))
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
require.NoError(t, err)
f := &Interface{
pki: pki,
firewall: fw,
l: l,
}
require.NoError(t, cfg.ReloadConfigString(rawYAML))
f.reloadFirewall(cfg)
assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
}
+27 -50
View File
@@ -15,7 +15,6 @@ import (
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
@@ -35,7 +34,6 @@ type LightHouse struct {
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Lite
punchConn udp.Conn
punchy *Punchy
// Local cache of answers from light houses
@@ -75,9 +73,8 @@ type LightHouse struct {
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *slog.Logger
metrics *MessageMetrics
l *slog.Logger
}
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
@@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
myVpnNetworksTable: cs.myVpnNetworksTable,
addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
updateTrigger: make(chan struct{}, 1),
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
@@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
if c.GetBool("stats.lighthouse_metrics", false) {
h.metrics = newLighthouseMetrics()
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
h.metricHolepunchTx = metrics.NilCounter{}
}
err := h.reload(c, true)
@@ -279,16 +272,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
// Clean up. Entries still in the static_host_map will be re-built.
// Entries no longer present must have their (possible) background DNS goroutines stopped.
if existingStaticList := lh.staticList.Load(); existingStaticList != nil {
ourselves := lh.myVpnNetworks[0].Addr()
oldStaticList := lh.staticList.Load()
if oldStaticList != nil {
lh.RLock()
for staticVpnAddr := range *existingStaticList {
for staticVpnAddr := range *oldStaticList {
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
am.hr.Cancel()
am.ResetForOwner(ourselves)
}
}
lh.RUnlock()
}
// Build a new list based on current config.
staticList := make(map[netip.Addr]struct{})
err := lh.loadStaticMap(c, staticList)
@@ -296,6 +291,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return err
}
// For entries removed from static_host_map, stop the DNS goroutine and drop the cached addrs.
// All addrs must come from the lighthouses now that it's no longer a static host.
if oldStaticList != nil {
lh.RLock()
for staticVpnAddr := range *oldStaticList {
if _, stillStatic := staticList[staticVpnAddr]; stillStatic {
continue
}
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
am.ClearHostnameResults()
}
}
lh.RUnlock()
}
lh.staticList.Store(&staticList)
if !initial {
if c.HasChanged("static_host_map") {
@@ -1406,58 +1416,25 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
return
}
empty := []byte{0}
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
if !vpnPeer.IsValid() {
return
}
go func() {
time.Sleep(lhh.lh.punchy.GetDelay())
lhh.lh.metricHolepunchTx.Inc(1)
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}()
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Punching",
"vpnPeer", vpnPeer,
"logVpnAddr", logVpnAddr,
)
}
}
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts {
b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
}
}
for _, a := range n.Details.V6AddrPorts {
b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
}
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Sending a nebula test packet",
"vpnAddr", detailsVpnAddr,
)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
// a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled.
lhh.lh.punchy.ScheduleRespond(detailsVpnAddr)
}
func protoAddrToNetAddr(addr *Addr) netip.Addr {
+126
View File
@@ -303,6 +303,132 @@ func TestLighthouse_reload(t *testing.T) {
require.NoError(t, err)
}
// TestLighthouse_reloadStaticHostMap verifies that reloading static_host_map applies the new
// config rather than appending to it. See issue #718.
func TestLighthouse_reloadStaticHostMap(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
"10.128.0.2": []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err)
staticHost := netip.MustParseAddr("10.128.0.2")
otherHost := netip.MustParseAddr("10.128.0.3")
// Capture the RemoteList pointer up front; an in-flight handshake would hold the same one
// on hostinfo.remotes, so it must reflect every reload below.
pinned := lh.Query(staticHost)
require.NotNil(t, pinned)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, pinned.CopyAddrs([]netip.Prefix{}))
// Replace the remote address. The new address should be the only entry.
nc := map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"2.2.2.2:4242"},
},
}
rc, err := yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
rl := lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl, "RemoteList pointer must stay stable so in-flight handshakes pick up the change")
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Reload back to the original IP. Mirrors the round-trip in issue #718 step 6-8 where
// the buggy reload produced [1.1.1.1, 2.2.2.2, 1.1.1.1] instead of [1.1.1.1].
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"1.1.1.1:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Reload with the same config. An unchanged entry must not duplicate.
require.NoError(t, c.ReloadConfigString(string(rc)))
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Switch back to 2.2.2.2 so the rest of the test continues against a known address.
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"2.2.2.2:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
// Add a second host alongside the first. Both should be present, neither duplicated.
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.2": []any{"2.2.2.2:4242"},
"10.128.0.3": []any{"3.3.3.3:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl, "adding a sibling entry must not displace the existing RemoteList")
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{}))
rl = lh.Query(otherHost)
require.NotNil(t, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{}))
// Drop the first host entirely. The vpnAddr is no longer marked static, our owner
// contribution is cleared, but the addrMap entry stays in place so non-static cache
// data (from lighthouse queries) on the same RemoteList isn't lost. In-flight handshakes
// that already had the pointer see an empty address list rather than retrying stale ones.
nc = map[string]any{
"static_host_map": map[string]any{
"10.128.0.3": []any{"3.3.3.3:4242"},
},
}
rc, err = yaml.Marshal(nc)
require.NoError(t, err)
require.NoError(t, c.ReloadConfigString(string(rc)))
_, isStatic := lh.GetStaticHostList()[staticHost]
assert.False(t, isStatic)
rl = lh.Query(staticHost)
require.NotNil(t, rl)
assert.Same(t, pinned, rl)
assert.Empty(t, rl.CopyAddrs([]netip.Prefix{}))
rl = lh.Query(otherHost)
require.NotNil(t, rl)
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{}))
}
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
+5 -3
View File
@@ -55,7 +55,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
}
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
ssh, err := sshd.NewSSHServer(ctx, l.With("subsystem", "sshd"))
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
}
@@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
}
hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c, udpConns[0])
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
@@ -194,7 +194,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
ds, err := newDnsServerFromConfig(ctx, l, pki, hostMap, c)
if err != nil {
l.Warn("Failed to start DNS responder", "error", err)
}
@@ -273,6 +273,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
handshakeManager.f = ifce
go handshakeManager.Run(ctx)
punchy.Start(ctx, ifce, hostMap, lightHouse)
}
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
-73
View File
@@ -1,73 +0,0 @@
package nebula
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
type endianness interface {
PutUint64(b []byte, v uint64)
}
var noiseEndianness endianness = binary.BigEndian
type NebulaCipherState struct {
c cipher.AEAD
}
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
x := s.Cipher()
return &NebulaCipherState{c: x.(cipher.AEAD)}
}
// EncryptDanger encrypts and authenticates a given payload.
//
// out is a destination slice to hold the output of the EncryptDanger operation.
// - ad is additional data, which will be authenticated and appended to out, but not encrypted.
// - plaintext is encrypted, authenticated and appended to out.
// - n is a nonce value which must never be re-used with this key.
// - nb is a buffer used for temporary storage in the implementation of this call, which should
// be re-used by callers to minimize garbage collection.
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
// TODO: Is this okay now that we have made messageCounter atomic?
// Alternative may be to split the counter space into ranges
//if n <= s.n {
// return nil, errors.New("CRITICAL: a duplicate counter value was used")
//}
//s.n = n
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
out = s.c.Seal(out, nb, plaintext, ad)
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
return out, nil
} else {
return nil, errors.New("no cipher state available to encrypt")
}
}
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
return s.c.Open(out, nb, ciphertext, ad)
} else {
return []byte{}, nil
}
}
func (s *NebulaCipherState) Overhead() int {
if s != nil {
return s.c.Overhead()
}
return 0
}
+53
View File
@@ -0,0 +1,53 @@
package noiseutil
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
// CipherStateAESGCM is the data-plane wrapper for the AES-GCM AEAD cipher.
// AES-GCM uses big-endian nonce encoding per the Noise spec.
type CipherStateAESGCM struct {
c cipher.AEAD
}
// NewCipherStateAESGCM extracts the underlying AEAD from the post-handshake noise.CipherState.
// The caller is responsible for ensuring the noise cipher is actually AES-GCM,
// otherwise the type assertion still succeeds but the nonce endianness will be wrong on the wire.
func NewCipherStateAESGCM(s *noise.CipherState) *CipherStateAESGCM {
return &CipherStateAESGCM{c: s.Cipher().(cipher.AEAD)}
}
func (s *CipherStateAESGCM) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
if s == nil {
return nil, errors.New("no cipher state available to encrypt")
}
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
binary.BigEndian.PutUint64(nb[4:], n)
return s.c.Seal(out, nb, plaintext, ad), nil
}
func (s *CipherStateAESGCM) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
if s == nil {
return []byte{}, nil
}
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
binary.BigEndian.PutUint64(nb[4:], n)
return s.c.Open(out, nb, ciphertext, ad)
}
func (s *CipherStateAESGCM) Overhead() int {
if s == nil {
return 0
}
return s.c.Overhead()
}
+52
View File
@@ -0,0 +1,52 @@
package noiseutil
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
// CipherStateChaChaPoly is the data-plane wrapper for the ChaCha20-Poly1305 AEAD cipher.
// ChaCha20-Poly1305 uses little-endian nonce encoding per the Noise spec.
type CipherStateChaChaPoly struct {
c cipher.AEAD
}
// NewCipherStateChaChaPoly extracts the underlying AEAD from the post-handshake noise.CipherState.
// The caller is responsible for ensuring the noise cipher is actually ChaCha20-Poly1305.
func NewCipherStateChaChaPoly(s *noise.CipherState) *CipherStateChaChaPoly {
return &CipherStateChaChaPoly{c: s.Cipher().(cipher.AEAD)}
}
func (s *CipherStateChaChaPoly) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
if s == nil {
return nil, errors.New("no cipher state available to encrypt")
}
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
binary.LittleEndian.PutUint64(nb[4:], n)
return s.c.Seal(out, nb, plaintext, ad), nil
}
func (s *CipherStateChaChaPoly) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
if s == nil {
return []byte{}, nil
}
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
binary.LittleEndian.PutUint64(nb[4:], n)
return s.c.Open(out, nb, ciphertext, ad)
}
func (s *CipherStateChaChaPoly) Overhead() int {
if s == nil {
return 0
}
return s.c.Overhead()
}
+40
View File
@@ -0,0 +1,40 @@
package noiseutil
import (
"fmt"
"github.com/flynn/noise"
)
// CipherState is the post-handshake AEAD cipher used for the data plane.
// Each supported cipher has its own concrete implementation in this package with the nonce endianness hardcoded,
// so the encrypt/decrypt fast path avoids interface dispatch on the byte order.
type CipherState interface {
// EncryptDanger encrypts and authenticates a given payload.
//
// out is a destination slice to hold the output of the EncryptDanger operation.
// - ad is additional data, which will be authenticated and appended to out, but not encrypted.
// - plaintext is encrypted, authenticated and appended to out.
// - n is a nonce value which must never be re-used with this key.
// - nb is a scratch buffer used to assemble the nonce.
EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error)
// DecryptDanger authenticates and decrypts a given payload, with the same argument shape as EncryptDanger.
DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error)
// Overhead returns the AEAD tag size, or 0 if the receiver is nil.
Overhead() int
}
// NewCipherState wraps the post-handshake noise.CipherState in the per-cipher type that matches cipherFunc.
// cipherFunc must be the same cipher used to build the noise CipherSuite that produced s.
func NewCipherState(s *noise.CipherState, cipherFunc noise.CipherFunc) CipherState {
switch cipherFunc.CipherName() {
case CipherAESGCM.CipherName():
return NewCipherStateAESGCM(s)
case noise.CipherChaChaPoly.CipherName():
return NewCipherStateChaChaPoly(s)
default:
panic(fmt.Sprintf("noiseutil: unsupported cipher %q", cipherFunc.CipherName()))
}
}
+166
View File
@@ -0,0 +1,166 @@
package noiseutil
import (
"testing"
"github.com/flynn/noise"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCipherStateAESGCMRoundtrip(t *testing.T) {
enc, dec := buildCipherStates(t, CipherAESGCM)
roundtrip(t, NewCipherStateAESGCM(enc), NewCipherStateAESGCM(dec))
}
func TestCipherStateChaChaPolyRoundtrip(t *testing.T) {
enc, dec := buildCipherStates(t, noise.CipherChaChaPoly)
roundtrip(t, NewCipherStateChaChaPoly(enc), NewCipherStateChaChaPoly(dec))
}
func TestNewCipherStateDispatch(t *testing.T) {
encA, _ := buildCipherStates(t, CipherAESGCM)
encC, _ := buildCipherStates(t, noise.CipherChaChaPoly)
assert.IsType(t, &CipherStateAESGCM{}, NewCipherState(encA, CipherAESGCM))
assert.IsType(t, &CipherStateChaChaPoly{}, NewCipherState(encC, noise.CipherChaChaPoly))
}
func TestNewCipherStateUnsupportedPanics(t *testing.T) {
enc, _ := buildCipherStates(t, CipherAESGCM)
assert.Panics(t, func() {
NewCipherState(enc, fakeCipher{})
})
}
type fakeCipher struct{}
func (fakeCipher) Cipher(k [32]byte) noise.Cipher { return nil }
func (fakeCipher) CipherName() string { return "Fake" }
// buildCipherStates runs an in-memory NN handshake with the requested cipher
// to produce a pair of post-handshake CipherStates that share keys.
func buildCipherStates(t *testing.T, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
t.Helper()
suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256)
cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN}
cfg.Initiator = true
hsI, err := noise.NewHandshakeState(cfg)
require.NoError(t, err)
cfg.Initiator = false
hsR, err := noise.NewHandshakeState(cfg)
require.NoError(t, err)
msg, _, _, err := hsI.WriteMessage(nil, nil)
require.NoError(t, err)
_, _, _, err = hsR.ReadMessage(nil, msg)
require.NoError(t, err)
msg, dR, _, err := hsR.WriteMessage(nil, nil)
require.NoError(t, err)
_, eI, _, err := hsI.ReadMessage(nil, msg)
require.NoError(t, err)
require.NotNil(t, eI)
require.NotNil(t, dR)
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
return eI, dR
}
func roundtrip(t *testing.T, enc, dec CipherState) {
t.Helper()
plaintext := []byte("nebula cipher state roundtrip")
ad := []byte("aad")
nb := make([]byte, 12)
ct, err := enc.EncryptDanger(nil, ad, plaintext, 1, nb)
require.NoError(t, err)
assert.NotEqual(t, plaintext, ct)
pt, err := dec.DecryptDanger(nil, ad, ct, 1, nb)
require.NoError(t, err)
assert.Equal(t, plaintext, pt)
// Wrong nonce must fail authentication.
_, err = dec.DecryptDanger(nil, ad, ct, 2, nb)
require.Error(t, err)
assert.Equal(t, enc.Overhead(), dec.Overhead())
assert.Equal(t, 16, enc.Overhead())
}
func BenchmarkCipherStateEncryptAESGCM(b *testing.B) {
enc, _ := buildCipherStatesB(b, CipherAESGCM)
benchEncryptCipherState(b, NewCipherState(enc, CipherAESGCM))
}
func BenchmarkCipherStateEncryptChaChaPoly(b *testing.B) {
enc, _ := buildCipherStatesB(b, noise.CipherChaChaPoly)
benchEncryptCipherState(b, NewCipherState(enc, noise.CipherChaChaPoly))
}
func benchEncryptCipherState(b *testing.B, cs CipherState) {
plaintext := make([]byte, 1280)
ad := make([]byte, 16)
nb := make([]byte, 12)
out := make([]byte, 0, len(plaintext)+cs.Overhead())
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
var err error
out, err = cs.EncryptDanger(out[:0], ad, plaintext, uint64(i+1), nb)
if err != nil {
b.Fatal(err)
}
}
}
func buildCipherStatesB(b *testing.B, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
b.Helper()
suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256)
cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN}
cfg.Initiator = true
hsI, err := noise.NewHandshakeState(cfg)
if err != nil {
b.Fatal(err)
}
cfg.Initiator = false
hsR, err := noise.NewHandshakeState(cfg)
if err != nil {
b.Fatal(err)
}
msg, _, _, err := hsI.WriteMessage(nil, nil)
if err != nil {
b.Fatal(err)
}
if _, _, _, err := hsR.ReadMessage(nil, msg); err != nil {
b.Fatal(err)
}
msg, dR, _, err := hsR.WriteMessage(nil, nil)
if err != nil {
b.Fatal(err)
}
_, eI, _, err := hsI.ReadMessage(nil, msg)
if err != nil {
b.Fatal(err)
}
return eI, dR
}
func TestCipherStateNilSafety(t *testing.T) {
var aes *CipherStateAESGCM
_, err := aes.EncryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.Error(t, err)
out, err := aes.DecryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.NoError(t, err)
assert.Empty(t, out)
assert.Equal(t, 0, aes.Overhead())
var cc *CipherStateChaChaPoly
_, err = cc.EncryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.Error(t, err)
out, err = cc.DecryptDanger(nil, nil, nil, 0, make([]byte, 12))
require.NoError(t, err)
assert.Empty(t, out)
assert.Equal(t, 0, cc.Overhead())
}
+2 -3
View File
@@ -194,8 +194,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
"vpnAddrs", hostinfo.vpnAddrs,
"remoteIndex", h.RemoteIndex,
"relayRemoteIndex", h.RemoteIndex,
)
return
}
@@ -218,8 +217,8 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
if err != nil {
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
"relayTo", relay.PeerAddr,
"relayFrom", hostinfo.vpnAddrs[0],
"error", err,
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
)
return
}
+358
View File
@@ -0,0 +1,358 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"errors"
"fmt"
"log/slog"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h.
type networkCategory int32
const (
networkCategoryPublic networkCategory = 0
networkCategoryPrivate networkCategory = 1
networkCategoryDomainAuthenticated networkCategory = 2
)
func (c networkCategory) String() string {
switch c {
case networkCategoryPublic:
return "public"
case networkCategoryPrivate:
return "private"
case networkCategoryDomainAuthenticated:
return "domain"
}
return fmt.Sprintf("unknown(%d)", c)
}
// parseNetworkCategory accepts the user-supplied tun.network_category. A
// second return of false means "leave the category alone".
func parseNetworkCategory(s string) (networkCategory, bool, error) {
switch strings.ToLower(strings.TrimSpace(s)) {
case "", "unset":
return 0, false, nil
case "public":
return networkCategoryPublic, true, nil
case "private":
return networkCategoryPrivate, true, nil
case "domain", "domainauthenticated":
return networkCategoryDomainAuthenticated, true, nil
}
return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s)
}
// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B}
var clsidNetworkListManager = windows.GUID{
Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B}
var iidINetworkListManager = windows.GUID{
Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B,
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
}
// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves.
var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance")
const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER |
windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER
const (
hrSFALSE = 0x00000001
hrRPCEChangedMode = 0x80010106
)
type hresult uint32
func (h hresult) failed() bool { return int32(h) < 0 }
func (h hresult) String() string {
return fmt.Sprintf("HRESULT 0x%08x", uint32(h))
}
var errAdapterNotFound = errors.New("adapter not present in network connections enumeration")
// Vtable layouts. Slot order must match the declaration order in netlistmgr.h.
// All NLM interfaces here derive from IDispatch, which derives from IUnknown.
type iUnknownVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
}
type iDispatchVtbl struct {
iUnknownVtbl
GetTypeInfoCount uintptr
GetTypeInfo uintptr
GetIDsOfNames uintptr
Invoke uintptr
}
type iNetworkListManagerVtbl struct {
iDispatchVtbl
GetNetworks uintptr
GetNetwork uintptr
GetNetworkConnections uintptr
GetNetworkConnection uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
}
type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl }
func (n *iNetworkListManager) Release() {
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}
func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) {
var enum *iEnumNetworkConnections
r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections,
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr)
}
return enum, nil
}
type iEnumNetworkConnectionsVtbl struct {
iDispatchVtbl
NewEnum uintptr
Next uintptr
Skip uintptr
Reset uintptr
Clone uintptr
}
type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl }
func (e *iEnumNetworkConnections) Release() {
syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e)))
}
// Next returns the next connection, or (nil, nil) at the end of the enumeration.
func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) {
var conn *iNetworkConnection
var fetched uint32
r1, _, _ := syscall.SyscallN(e.Vtbl.Next,
uintptr(unsafe.Pointer(e)), 1,
uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr)
}
if fetched == 0 {
return nil, nil
}
return conn, nil
}
type iNetworkConnectionVtbl struct {
iDispatchVtbl
GetNetwork uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
GetConnectionId uintptr
GetAdapterId uintptr
GetDomainType uintptr
}
type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl }
func (c *iNetworkConnection) Release() {
syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c)))
}
func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) {
var g windows.GUID
r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId,
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)),
)
if hr := hresult(r1); hr.failed() {
return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr)
}
return g, nil
}
func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) {
var net *iNetwork
r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork,
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr)
}
return net, nil
}
type iNetworkVtbl struct {
iDispatchVtbl
GetName uintptr
SetName uintptr
GetDescription uintptr
SetDescription uintptr
GetNetworkId uintptr
GetDomainType uintptr
GetNetworkConnections uintptr
GetTimeCreatedAndConnected uintptr
IsConnectedToInternet uintptr
IsConnected uintptr
GetConnectivity uintptr
GetCategory uintptr
SetCategory uintptr
}
type iNetwork struct{ Vtbl *iNetworkVtbl }
func (n *iNetwork) Release() {
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}
func (n *iNetwork) GetCategory() (networkCategory, error) {
var c networkCategory
r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory,
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)),
)
if hr := hresult(r1); hr.failed() {
return 0, fmt.Errorf("INetwork.GetCategory: %s", hr)
}
return c, nil
}
func (n *iNetwork) SetCategory(c networkCategory) error {
r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory,
uintptr(unsafe.Pointer(n)), uintptr(int32(c)),
)
if hr := hresult(r1); hr.failed() {
return fmt.Errorf("INetwork.SetCategory: %s", hr)
}
return nil
}
// coInit initializes COM for the current OS thread. The returned function must
// be deferred to balance a successful init. RPC_E_CHANGED_MODE means COM is
// already initialized in a different mode on this thread, which is still fine
// for our calls but we must not Uninitialize in that case.
func coInit() (func(), error) {
err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
if err == nil {
return windows.CoUninitialize, nil
}
if e, ok := err.(syscall.Errno); ok {
switch uint32(e) {
case hrSFALSE:
return windows.CoUninitialize, nil
case hrRPCEChangedMode:
return func() {}, nil
}
}
return nil, fmt.Errorf("CoInitializeEx: %w", err)
}
func createNetworkListManager() (*iNetworkListManager, error) {
var nlm *iNetworkListManager
r1, _, _ := procCoCreateInstance.Call(
uintptr(unsafe.Pointer(&clsidNetworkListManager)),
0,
uintptr(clsCtxAll),
uintptr(unsafe.Pointer(&iidINetworkListManager)),
uintptr(unsafe.Pointer(&nlm)),
)
if hr := hresult(r1); hr.failed() {
return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr)
}
return nlm, nil
}
// setNetworkCategory locates the network connection bound to adapterGUID and
// sets the category of its parent network. Returns errAdapterNotFound if the
// adapter is not yet visible in the NLM enumeration.
func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error {
deinit, err := coInit()
if err != nil {
return err
}
defer deinit()
nlm, err := createNetworkListManager()
if err != nil {
return err
}
defer nlm.Release()
enum, err := nlm.GetNetworkConnections()
if err != nil {
return err
}
defer enum.Release()
for {
conn, err := enum.Next()
if err != nil {
return err
}
if conn == nil {
return errAdapterNotFound
}
guid, err := conn.GetAdapterId()
if err != nil || guid != adapterGUID {
conn.Release()
continue
}
net, err := conn.GetNetwork()
conn.Release()
if err != nil {
return err
}
err = net.SetCategory(cat)
net.Release()
return err
}
}
// applyNetworkCategory polls until the wintun adapter shows up in the NLM
// enumeration, then sets the category. Intended to run in its own goroutine.
func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) {
// COM Init/Uninit must be paired on the same OS thread.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
const (
attempts = 30
interval = 500 * time.Millisecond
)
for i := 0; i < attempts; i++ {
err := setNetworkCategory(adapterGUID, cat)
if err == nil {
l.Info("Set Windows network category", "category", cat.String())
return
}
if !errors.Is(err, errAdapterNotFound) {
l.Warn("Failed to set Windows network category", "error", err, "category", cat.String())
return
}
time.Sleep(interval)
}
l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set",
"category", cat.String(),
"waited", time.Duration(attempts)*interval,
)
}
+109
View File
@@ -0,0 +1,109 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"testing"
)
func Test_parseNetworkCategory(t *testing.T) {
cases := []struct {
in string
wantCat networkCategory
wantApply bool
wantErr bool
}{
{"", 0, false, false},
{"unset", 0, false, false},
{" UNSET ", 0, false, false},
{"private", networkCategoryPrivate, true, false},
{"Private", networkCategoryPrivate, true, false},
{" PRIVATE ", networkCategoryPrivate, true, false},
{"public", networkCategoryPublic, true, false},
{"PUBLIC", networkCategoryPublic, true, false},
{"domain", networkCategoryDomainAuthenticated, true, false},
{"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false},
{"garbage", 0, false, true},
{"privates", 0, false, true},
}
for _, tc := range cases {
cat, apply, err := parseNetworkCategory(tc.in)
if (err != nil) != tc.wantErr {
t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr)
continue
}
if cat != tc.wantCat || apply != tc.wantApply {
t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", tc.in, cat, apply, tc.wantCat, tc.wantApply)
}
}
}
// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory
// without mutating the host's network state. It validates the CLSID/IID
// constants and every vtable index by enumerating connections, fetching the
// adapter id and parent network, reading the current category, and writing it
// back unchanged.
//
// Requires Windows but does not require admin or the wintun driver. Skips if
// no network connections are available (unlikely outside of an isolated
// container).
func Test_NLM_round_trip(t *testing.T) {
deinit, err := coInit()
if err != nil {
t.Fatalf("coInit: %v", err)
}
defer deinit()
nlm, err := createNetworkListManager()
if err != nil {
t.Fatalf("createNetworkListManager: %v", err)
}
defer nlm.Release()
enum, err := nlm.GetNetworkConnections()
if err != nil {
t.Fatalf("GetNetworkConnections: %v", err)
}
defer enum.Release()
saw := 0
for {
conn, err := enum.Next()
if err != nil {
t.Fatalf("EnumNetworkConnections.Next: %v", err)
}
if conn == nil {
break
}
saw++
if _, err := conn.GetAdapterId(); err != nil {
conn.Release()
t.Fatalf("INetworkConnection.GetAdapterId: %v", err)
}
net, err := conn.GetNetwork()
conn.Release()
if err != nil {
t.Fatalf("INetworkConnection.GetNetwork: %v", err)
}
cat, err := net.GetCategory()
if err != nil {
net.Release()
t.Fatalf("INetwork.GetCategory: %v", err)
}
// Set to the current value so the host's NLM state is unchanged but
// SetCategory's vtable slot is still validated end-to-end.
if err := net.SetCategory(cat); err != nil {
net.Release()
t.Fatalf("INetwork.SetCategory(%v): %v", cat, err)
}
net.Release()
}
if saw == 0 {
t.Skip("no NLM network connections available; skipping round-trip")
}
}
+23
View File
@@ -0,0 +1,23 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
package overlay
import (
"log/slog"
"github.com/slackhq/nebula/wfp"
)
// installInterfaceBypass installs a WFP PERMIT filter scoped to the wintun interface LUID so inbound traffic on the
// nebula adapter bypasses Windows Defender Firewall.
func installInterfaceBypass(l *slog.Logger, luid uint64) closer {
s, err := wfp.PermitInterface(luid)
if err != nil {
l.Warn("Failed to install WFP bypass filters on nebula interface", "error", err)
return nil
}
l.Info("Installed WFP filters bypassing Windows Defender Firewall on nebula interface")
return s
}
+11
View File
@@ -0,0 +1,11 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import "log/slog"
// installInterfaceBypass is a no-op on windows-386 because we don't currently build for it.
func installInterfaceBypass(_ *slog.Logger, _ uint64) closer {
return nil
}
+54 -16
View File
@@ -25,15 +25,24 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
type closer interface {
Close()
}
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *slog.Logger
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
guid windows.GUID
networkCategory networkCategory
setCategory bool
bypassWDF bool
wdfBypass closer
l *slog.Logger
tun *wintun.NativeTun
}
@@ -54,11 +63,20 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private"))
if err != nil {
return nil, err
}
t := &winTun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
guid: *guid,
networkCategory: cat,
setCategory: setCat,
bypassWDF: c.GetBool("tun.windows_bypass_wdf", true),
l: l,
}
err = t.reload(c, true)
@@ -142,6 +160,17 @@ func (t *winTun) Activate() error {
return err
}
if t.setCategory {
// The wintun adapter takes a moment to register with the Network List
// Manager, so we apply the category in the background and retry until
// it shows up.
go applyNetworkCategory(t.l, t.guid, t.networkCategory)
}
if t.bypassWDF {
t.wdfBypass = installInterfaceBypass(t.l, uint64(t.tun.LUID()))
}
return nil
}
@@ -156,11 +185,8 @@ func (t *winTun) addRoutes(logErrors bool) error {
continue
}
// Add our unsafe route
// Windows does not support multipath routes natively, so we install only a single route.
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
// Add our unsafe route as an on-link route to the nebula tun device.
err := luid.AddRoute(r.Cidr, unspecifiedNextHop(r.Cidr), uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
@@ -206,7 +232,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
}
// See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
err := luid.DeleteRoute(r.Cidr, unspecifiedNextHop(r.Cidr))
if err != nil {
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
@@ -258,9 +284,21 @@ func (t *winTun) Close() error {
_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)
if t.wdfBypass != nil {
t.wdfBypass.Close()
t.wdfBypass = nil
}
return t.tun.Close()
}
func unspecifiedNextHop(p netip.Prefix) netip.Addr {
if p.Addr().Is4() {
return netip.IPv4Unspecified()
}
return netip.IPv6Unspecified()
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
+3 -5
View File
@@ -99,12 +99,10 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
var currentState *CertState
if initial {
cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
case "aes", "chachapoly":
// Each post-handshake CipherState in noiseutil hardcodes its own
// nonce endianness now, so there's nothing to set up here.
default:
return util.NewContextualError(
"unknown cipher",
+159 -34
View File
@@ -1,24 +1,70 @@
package nebula
import (
"context"
"log/slog"
"net/netip"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
)
// holepunchQueueSize buffers the channel that pending holepunchJobs land on after their delay timer fires.
const holepunchQueueSize = 64
// holepunchJob is one scheduled item delivered to the worker goroutine.
// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context.
// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback").
type holepunchJob struct {
target netip.AddrPort
vpnAddr netip.Addr
}
// lighthouseChecker is the slice of LightHouse that Punchy actually needs.
// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse
// already holds a *Punchy, and the bidirectional pointer reference is awkward
// even within the same package). Tests can also substitute a fake.
type lighthouseChecker interface {
IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool
}
type Punchy struct {
punch atomic.Bool
respond atomic.Bool
delay atomic.Int64
respondDelay atomic.Int64
punchEverything atomic.Bool
l *slog.Logger
sched *Scheduler[holepunchJob]
punchConn udp.Conn
metricHolepunchTx metrics.Counter
metricPunchyTx metrics.Counter
ctx context.Context
ifce EncWriter
hm *HostMap
lh lighthouseChecker
l *slog.Logger
}
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
p := &Punchy{l: l}
func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy {
p := &Punchy{
l: l,
punchConn: punchConn,
sched: NewScheduler[holepunchJob](holepunchQueueSize),
metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
}
if c.GetBool("stats.lighthouse_metrics", false) {
p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
p.metricHolepunchTx = metrics.NilCounter{}
}
p.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
@@ -29,7 +75,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
}
func (p *Punchy) reload(c *config.C, initial bool) {
if initial {
if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
var yes bool
if c.IsSet("punchy.punch") {
yes = c.GetBool("punchy.punch", false)
@@ -38,16 +84,15 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punchy", false)
}
p.punch.Store(yes)
if yes {
old := p.punch.Swap(yes)
switch {
case initial && yes:
p.l.Info("punchy enabled")
} else {
case initial:
p.l.Info("punchy disabled")
case old != yes:
p.l.Info("punchy.punch changed", "punch", yes)
}
} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
}
if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") {
@@ -59,52 +104,132 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punch_back", false)
}
p.respond.Store(yes)
if !initial {
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
old := p.respond.Swap(yes)
if !initial && old != yes {
p.l.Info("punchy.respond changed", "respond", yes)
}
}
//NOTE: this will not apply to any in progress operations, only the next one
if initial || c.HasChanged("punchy.delay") {
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
if !initial {
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
newDelay := int64(c.GetDuration("punchy.delay", time.Second))
old := p.delay.Swap(newDelay)
if !initial && old != newDelay {
p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay))
}
}
if initial || c.HasChanged("punchy.target_all_remotes") {
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
if !initial {
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
yes := c.GetBool("punchy.target_all_remotes", false)
old := p.punchEverything.Swap(yes)
if !initial && old != yes {
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes)
}
}
if initial || c.HasChanged("punchy.respond_delay") {
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
if !initial {
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second))
old := p.respondDelay.Swap(newDelay)
if !initial && old != newDelay {
p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay))
}
}
}
func (p *Punchy) GetPunch() bool {
return p.punch.Load()
// Schedule queues a punch packet to target, to be sent after the configured delay.
// vpnAddr is the peer's vpn addr, used for log context when the packet actually fires.
// No-op if target is not a valid AddrPort or if Start has not yet been called. Safe to call from any goroutine.
func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) {
if !target.IsValid() || p.ctx == nil {
return
}
p.scheduleJob(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load()))
}
func (p *Punchy) GetRespond() bool {
return p.respond.Load()
// ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay,
// gated on punchy.respond. No-op when respond is disabled or before Start has been called.
func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) {
if !p.respond.Load() || p.ctx == nil {
return
}
p.scheduleJob(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load()))
}
func (p *Punchy) GetDelay() time.Duration {
return (time.Duration)(p.delay.Load())
// scheduleJob delegates to the pooled Scheduler.
// The callback observes p.ctx so a job that becomes due after Stop is dropped instead of queued.
func (p *Punchy) scheduleJob(job holepunchJob, delay time.Duration) {
p.sched.Schedule(p.ctx, job, delay)
}
func (p *Punchy) GetRespondDelay() time.Duration {
return (time.Duration)(p.respondDelay.Load())
// SendPunch sends an immediate keepalive punch for an idle hostinfo.
// The configured punchy.target_all_remotes mode picks the targets. Gated on punchy.punch and the lighthouse-skip rule
// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm).
func (p *Punchy) SendPunch(hostinfo *HostInfo) {
if !p.punch.Load() {
return
}
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
return
}
if p.punchEverything.Load() {
p.sendPunchToAllRemotes(hostinfo)
} else if hostinfo.remote.IsValid() {
p.metricPunchyTx.Inc(1)
p.punchConn.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (p *Punchy) GetTargetEverything() bool {
return p.punchEverything.Load()
// SendPunchToAll punches every known remote for hostinfo, but only when punchy.target_all_remotes is enabled.
// The connection manager calls this during outbound-only traffic: the outbound traffic itself keeps the primary's
// NAT state warm, but non-primary remotes need separate refresh, so we fan out to all of them (the redundant
// primary punch is harmless). Gated on punchy.punch and the lighthouse-skip rule.
func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) {
if !p.punchEverything.Load() {
return
}
if !p.punch.Load() {
return
}
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
return
}
p.sendPunchToAllRemotes(hostinfo)
}
func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) {
hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
p.metricPunchyTx.Inc(1)
p.punchConn.WriteTo([]byte{1}, addr)
})
}
// Start wires the runtime dependencies and spawns the scheduler worker.
func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) {
p.ctx = ctx
p.ifce = ifce
p.hm = hm
p.lh = lh
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
empty := []byte{0}
go p.sched.Run(ctx, func(job holepunchJob) {
switch {
case job.target.IsValid():
if p.l.Enabled(context.Background(), slog.LevelDebug) {
p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr)
}
p.metricHolepunchTx.Inc(1)
p.punchConn.WriteTo(empty, job.target)
case job.vpnAddr.IsValid():
// A nebula test packet to the host trying to contact us.
// In the case of a double nat or other difficult scenario, this may help establish a tunnel.
if p.l.Enabled(context.Background(), slog.LevelDebug) {
p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr)
}
p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out)
}
})
}
+40 -41
View File
@@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) {
c := config.NewC(l)
// Test defaults
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.False(t, p.GetPunch())
assert.False(t, p.GetRespond())
assert.Equal(t, time.Second, p.GetDelay())
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.False(t, p.punch.Load())
assert.False(t, p.respond.Load())
assert.Equal(t, time.Second, time.Duration(p.delay.Load()))
assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load()))
// punchy deprecation
c.Settings["punchy"] = true
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.punch.Load())
// punchy.punch
c.Settings["punchy"] = map[string]any{"punch": true}
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.punch.Load())
// punch_back deprecation
c.Settings["punch_back"] = true
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.respond.Load())
// punchy.respond
c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punch_back"] = false
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.True(t, p.respond.Load())
// punchy.delay
c.Settings["punchy"] = map[string]any{"delay": "1m"}
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetDelay())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.Equal(t, time.Minute, time.Duration(p.delay.Load()))
// punchy.respond_delay
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetRespondDelay())
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load()))
}
func TestPunchy_reload(t *testing.T) {
@@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) {
delay, _ := time.ParseDuration("1m")
require.NoError(t, c.LoadString(`
punchy:
punch: false
delay: 1m
respond: false
`))
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, delay, p.GetDelay())
assert.False(t, p.GetRespond())
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
assert.False(t, p.punch.Load())
assert.Equal(t, delay, time.Duration(p.delay.Load()))
assert.False(t, p.respond.Load())
newDelay, _ := time.ParseDuration("10m")
require.NoError(t, c.ReloadConfigString(`
punchy:
punch: true
delay: 10m
respond: true
`))
p.reload(c, false)
assert.Equal(t, newDelay, p.GetDelay())
assert.True(t, p.GetRespond())
assert.True(t, p.punch.Load())
assert.Equal(t, newDelay, time.Duration(p.delay.Load()))
assert.True(t, p.respond.Load())
}
// The tests below pin the shape of each log line Punchy produces so changes
// cannot silently break whatever operators are grepping for. The assertions
// are on the structured message + attrs (e.g. "punchy.respond changed" with
// a respond=true field) rather than a formatted string.
//
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
// not supported" warning whenever any key under punchy changes, because of
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
// punchy form. The tests filter by message rather than asserting total
// entry counts so that warning is tolerated without being locked into
// the format.
// a respond=true field) rather than a formatted string. Tests filter by
// message rather than asserting total entry counts so unrelated info lines
// are tolerated without being locked into the format.
type capturedEntry struct {
Level slog.Level
@@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
entry := findEntry(t, hook.entries, "punchy enabled")
assert.Equal(t, slog.LevelInfo, entry.Level)
@@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
entry := findEntry(t, hook.entries, "punchy disabled")
assert.Equal(t, slog.LevelInfo, entry.Level)
assert.Empty(t, entry.Attrs)
}
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
func TestPunchy_LogFormat_ReloadPunch(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
assert.Equal(t, slog.LevelWarn, entry.Level)
assert.Empty(t, entry.Attrs)
entry := findEntry(t, hook.entries, "punchy.punch changed")
assert.Equal(t, slog.LevelInfo, entry.Level)
assert.Equal(t, map[string]any{"punch": true}, entry.Attrs)
}
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
@@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
@@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
@@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
l, hook := newCapturingPunchyLogger(t)
c := config.NewC(test.NewLogger())
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
NewPunchyFromConfig(l, c)
NewPunchyFromConfig(l, c, nil)
hook.entries = nil
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
+37 -18
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"net/netip"
"slices"
"sync/atomic"
"github.com/slackhq/nebula/cert"
@@ -57,14 +58,25 @@ func (rm *relayManager) GetUseRelays() bool {
// For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits
// one that may have been lost, or, once the relay is Established, forwards the in-progress
// stage 0 handshake packet for vpnIp through it.
func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *HostInfo, stage0 []byte) {
func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hh *HandshakeHostInfo, stage0 []byte) {
hostinfo := hh.hostinfo
if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 {
hh.lastRelays = nil
return
}
hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
relays := hostinfo.remotes.relays
listLevel := slog.LevelDebug
prior := hh.lastRelays
if !slices.Equal(relays, prior) {
listLevel = slog.LevelInfo
hh.lastRelays = slices.Clone(relays)
}
hl := hostinfo.logger(rm.l)
hl.Log(context.Background(), listLevel, "Attempt to relay through hosts", "relays", relays)
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
for _, relay := range relays {
// Don't relay through the host I'm trying to connect to
if relay == vpnIp {
continue
@@ -75,12 +87,19 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
continue
}
// Each relay's per-attempt log fires at Info on the first time we hit it and Debug after that.
level := slog.LevelInfo
if slices.Contains(prior, relay) {
level = slog.LevelDebug
}
relayHostInfo := rm.hostmap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(rm.l).Info("Establish tunnel to relay target", "relay", relay.String())
hl.Log(context.Background(), level, "Establish tunnel to relay target", "relay", relay.String())
f.Handshake(relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through
existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
if !ok {
@@ -88,7 +107,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(rm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
hl.Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
}
m := NebulaControl{
@@ -99,12 +118,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !f.myVpnAddrs[0].Is4() {
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
@@ -116,16 +135,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay")
hl.Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err)
hl.Error("Failed to marshal Control message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.Info("send CreateRelayRequest",
rm.l.Log(context.Background(), level, "send CreateRelayRequest",
"relayFrom", f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", idx,
@@ -138,14 +157,14 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
switch existingRelay.State {
case Established:
hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String())
hl.Log(context.Background(), level, "Send handshake via relay", "relay", relay.String())
f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hostinfo.logger(rm.l).Info("Re-send CreateRelay request", "relay", relay.String())
hl.Log(context.Background(), level, "Re-send CreateRelay request", "relay", relay.String())
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
@@ -155,12 +174,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !f.myVpnAddrs[0].Is4() {
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
@@ -172,16 +191,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay")
hl.Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err)
hl.Error("Failed to marshal Control message to create relay", "error", err)
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.Info("send CreateRelayRequest",
rm.l.Log(context.Background(), level, "send CreateRelayRequest",
"relayFrom", f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", existingRelay.LocalIndex,
@@ -192,7 +211,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hostinfo.logger(rm.l).Error("Relay unexpected state",
hl.Error("Relay unexpected state",
"vpnIp", vpnIp,
"state", existingRelay.State,
"relay", relay,
+97
View File
@@ -0,0 +1,97 @@
package nebula
import (
"bytes"
"log/slog"
"net/netip"
"testing"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
// TestStartRelaysLogDedupe verifies that repeated attempts with the same relay set drop the log
// chatter to Debug, mirroring how the normal handshake retry loop quiets down once it's already
// announced its targets.
func TestStartRelaysLogDedupe(t *testing.T) {
vpnIp := netip.MustParseAddr("100.64.99.4")
otherRelay := netip.MustParseAddr("100.64.99.5")
newHH := func() *HandshakeHostInfo {
// Use the target's own vpnIp as the "relay" so the loop body skips it without
// touching any sender-side state. That isolates the test to the level-selection
// behavior of the top-level "Attempt to relay through hosts" log.
hostinfo := &HostInfo{
vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1,
remotes: NewRemoteList([]netip.Addr{vpnIp}, nil),
}
hostinfo.remotes.relays = []netip.Addr{vpnIp}
return &HandshakeHostInfo{hostinfo: hostinfo}
}
// Park any extra relay addresses we'll introduce mid-test in myVpnAddrsTable so the loop
// body always skips before touching f.Handshake (which would need a real handshakeManager).
addrTable := new(bart.Lite)
addrTable.Insert(netip.PrefixFrom(otherRelay, otherRelay.BitLen()))
f := &Interface{myVpnAddrsTable: addrTable}
newRM := func(buf *bytes.Buffer) *relayManager {
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
rm := &relayManager{l: l, hostmap: newHostMap(l)}
rm.useRelays.Store(true)
return rm
}
const msg = `msg="Attempt to relay through hosts"`
t.Run("first attempt logs at Info", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
hh := newHH()
rm.StartRelays(f, vpnIp, hh, nil)
assert.Equal(t, []netip.Addr{vpnIp}, hh.lastRelays, "lastRelays should record the relay set we just attempted")
assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info level on first attempt")
})
t.Run("repeat attempt with same relays drops to Debug", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
hh := newHH()
rm.StartRelays(f, vpnIp, hh, nil)
first := append([]netip.Addr(nil), hh.lastRelays...)
buf.Reset()
rm.StartRelays(f, vpnIp, hh, nil)
assert.Equal(t, first, hh.lastRelays)
assert.Contains(t, buf.String(), "level=DEBUG "+msg, "expected Debug level on identical retry")
assert.NotContains(t, buf.String(), "level=INFO "+msg, "Info should not fire on identical retry")
})
t.Run("changed relay list bumps back to Info", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
hh := newHH()
rm.StartRelays(f, vpnIp, hh, nil)
buf.Reset()
// The lighthouse handed us a new set this round.
hh.hostinfo.remotes.relays = []netip.Addr{vpnIp, otherRelay}
rm.StartRelays(f, vpnIp, hh, nil)
assert.Equal(t, []netip.Addr{vpnIp, otherRelay}, hh.lastRelays)
assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info when the relay list changes")
})
t.Run("disabled relays clears lastRelays and emits no Attempt log", func(t *testing.T) {
var buf bytes.Buffer
rm := newRM(&buf)
rm.useRelays.Store(false)
hh := newHH()
hh.lastRelays = []netip.Addr{vpnIp}
rm.StartRelays(f, vpnIp, hh, nil)
assert.Nil(t, hh.lastRelays, "with relays disabled lastRelays should be cleared")
assert.NotContains(t, buf.String(), msg, "should not log when we shortcut out")
})
}
+25
View File
@@ -239,6 +239,31 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
r.hr = hr
}
// ResetForOwner zeros the reported address slices for the given owner and marks the addrs list dirty.
// Any pending hostname resolution will be canceled.
func (r *RemoteList) ResetForOwner(ownerVpnAddr netip.Addr) {
r.Lock()
defer r.Unlock()
r.hr.Cancel()
if c, ok := r.cache[ownerVpnAddr]; ok {
if c.v4 != nil {
c.v4.reported = c.v4.reported[:0]
}
if c.v6 != nil {
c.v6.reported = c.v6.reported[:0]
}
}
r.shouldRebuild = true
}
// ClearHostnameResults cancels the in-flight DNS resolver goroutine (if any) and drops the resolved IP cache.
func (r *RemoteList) ClearHostnameResults() {
r.Lock()
defer r.Unlock()
r.unlockedSetHostnamesResults(nil)
r.shouldRebuild = true
}
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
+89
View File
@@ -6,8 +6,22 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// trackedHostnameResults builds a *hostnamesResults with a known cancel function and a
// pre-populated ips map so tests can assert cancellation and verify previously-resolved
// IPs survive a cancel without spinning up a real DNS resolver.
func trackedHostnameResults(cancelFn func(), addrs ...string) *hostnamesResults {
hr := &hostnamesResults{cancelFn: cancelFn}
ips := map[netip.AddrPort]struct{}{}
for _, a := range addrs {
ips[netip.MustParseAddrPort(a)] = struct{}{}
}
hr.ips.Store(&ips)
return hr
}
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4(
@@ -112,6 +126,81 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
}
func TestRemoteList_ResetForOwner(t *testing.T) {
ourselves := netip.MustParseAddr("10.0.0.1")
otherOwner := netip.MustParseAddr("10.0.0.2")
vpnAddr := netip.MustParseAddr("10.0.0.99")
rl := NewRemoteList([]netip.Addr{vpnAddr}, nil)
rl.unlockedSetV4(ourselves, vpnAddr,
[]*V4AddrPort{newIp4AndPortFromString("1.1.1.1:4242")},
func(netip.Addr, *V4AddrPort) bool { return true },
)
rl.unlockedSetV6(ourselves, vpnAddr,
[]*V6AddrPort{newIp6AndPortFromString("[1::1]:4242")},
func(netip.Addr, *V6AddrPort) bool { return true },
)
rl.unlockedSetV4(otherOwner, vpnAddr,
[]*V4AddrPort{newIp4AndPortFromString("2.2.2.2:4242")},
func(netip.Addr, *V4AddrPort) bool { return true },
)
canceled := 0
hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")
rl.Lock()
rl.unlockedSetHostnamesResults(hr)
rl.Unlock()
rl.ResetForOwner(ourselves)
rl.RLock()
defer rl.RUnlock()
assert.Empty(t, rl.cache[ourselves].v4.reported, "our v4 reported should be cleared")
assert.Empty(t, rl.cache[ourselves].v6.reported, "our v6 reported should be cleared")
assert.Len(t, rl.cache[otherOwner].v4.reported, 1, "other owner's contribution must be preserved")
assert.Equal(t, "2.2.2.2:4242", protoV4AddrPortToNetAddrPort(rl.cache[otherOwner].v4.reported[0]).String())
assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled")
assert.Same(t, hr, rl.hr, "hostnamesResults must be preserved so DNS-resolved IPs keep feeding addrs until replaced")
assert.NotEmpty(t, rl.hr.GetAddrs(), "previously-resolved IPs should still be readable after cancel")
assert.True(t, rl.shouldRebuild, "shouldRebuild must be set so the next Rebuild recomputes addrs")
}
func TestRemoteList_ResetForOwner_NoEntry(t *testing.T) {
// An owner with no cache entry must not panic; shouldRebuild is still set and any
// existing hostnamesResults is canceled.
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil)
canceled := 0
rl.Lock()
rl.unlockedSetHostnamesResults(trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242"))
rl.Unlock()
rl.ResetForOwner(netip.MustParseAddr("10.0.0.1"))
rl.RLock()
defer rl.RUnlock()
assert.Equal(t, 1, canceled)
assert.True(t, rl.shouldRebuild)
}
func TestRemoteList_ClearHostnameResults(t *testing.T) {
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil)
canceled := 0
hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")
rl.Lock()
rl.unlockedSetHostnamesResults(hr)
rl.Unlock()
require.NotEmpty(t, hr.GetAddrs(), "hostnamesResults should have its fastrack IPs populated")
rl.ClearHostnameResults()
rl.RLock()
defer rl.RUnlock()
assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled")
assert.Nil(t, rl.hr, "hostnamesResults should be dropped")
assert.True(t, rl.shouldRebuild)
}
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4(
+84
View File
@@ -0,0 +1,84 @@
package nebula
import (
"context"
"sync"
"time"
)
// Scheduler is an allocation-conscious dispatch primitive for delayed work.
// Pending items are handed to time.AfterFunc, and ready items land on a worker
// channel for centralized dispatch in fire-time order.
//
// Pick a Scheduler when fire timing matters (exact deadlines, no bucketing) or when the scheduling
// rate is uneven enough that idle CPU matters. Each fire is a runtime-spawned goroutine running the callback before
// delivering to the worker, which is fine at sparse rates but adds up at line rate.
//
// Pick a TimerWheel when scheduling is high-rate and uniform: its O(1) insert, internal item cache,
// and bucket-batched dispatch are cheaper at scale.
// The caller drives the tick loop (Advance/Purge) and pays for fires at bucket boundaries rather than exact deadlines.
type Scheduler[T any] struct {
queue chan T
pool sync.Pool
}
type schedItem[T any] struct {
val T
ctx context.Context
s *Scheduler[T]
timer *time.Timer
fire func()
}
// NewScheduler builds a Scheduler whose worker channel is sized to queueSize.
// The buffer absorbs bursts of timers firing close together without
// blocking the runtime's callback goroutines on the worker.
func NewScheduler[T any](queueSize int) *Scheduler[T] {
s := &Scheduler[T]{
queue: make(chan T, queueSize),
}
s.pool.New = func() any {
si := &schedItem[T]{s: s}
// fire is allocated exactly once per pool-resident item.
// The closure captures only `si`, which stays stable for the item's lifetime.
si.fire = func() {
select {
case si.s.queue <- si.val:
case <-si.ctx.Done():
}
var zero T
si.val = zero
si.ctx = nil
si.s.pool.Put(si)
}
return si
}
return s
}
// Schedule arranges item to be delivered to the worker after delay.
// The runtime's timer heap handles the wait, so the scheduler itself burns no CPU while idle.
// The callback observes ctx: if ctx is cancelled before the timer fires, the item is dropped instead of queued.
func (s *Scheduler[T]) Schedule(ctx context.Context, item T, delay time.Duration) {
si := s.pool.Get().(*schedItem[T])
si.val = item
si.ctx = ctx
if si.timer == nil {
si.timer = time.AfterFunc(delay, si.fire)
} else {
si.timer.Reset(delay)
}
}
// Run drains the worker queue, calling fn for each item. Returns when ctx is cancelled.
// Tests that want deterministic timing should drive the queue directly rather than going through Schedule + Run.
func (s *Scheduler[T]) Run(ctx context.Context, fn func(T)) {
for {
select {
case <-ctx.Done():
return
case item := <-s.queue:
fn(item)
}
}
}
+79
View File
@@ -0,0 +1,79 @@
package nebula
import (
"context"
"testing"
"time"
)
func TestScheduler_PooledReuse(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := NewScheduler[int](16)
delivered := make(chan int, 256)
go s.Run(ctx, func(item int) { delivered <- item })
const N = 100
for i := 0; i < N; i++ {
s.Schedule(ctx, i, time.Millisecond)
}
deadline := time.After(2 * time.Second)
got := 0
for got < N {
select {
case <-delivered:
got++
case <-deadline:
t.Fatalf("only %d/%d items delivered", got, N)
}
}
}
// BenchmarkScheduler_Schedule reports allocations per Schedule call.
// In steady state the Scheduler's sync.Pool means we should see zero allocs per op once the pool warms up.
func BenchmarkScheduler_Schedule(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := NewScheduler[int](b.N)
go s.Run(ctx, func(int) {})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Schedule(ctx, i, time.Microsecond)
}
}
// BenchmarkBareAfterFunc is the comparison baseline.
// What we'd pay per Schedule if Punchy called time.AfterFunc directly without the pooled Scheduler.
// Allocates a *time.Timer plus a closure each call.
func BenchmarkBareAfterFunc(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
queue := make(chan int, b.N)
go func() {
for {
select {
case <-ctx.Done():
return
case <-queue:
}
}
}()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
i := i
time.AfterFunc(time.Microsecond, func() {
select {
case queue <- i:
case <-ctx.Done():
}
})
}
}
+38 -21
View File
@@ -27,21 +27,20 @@ type SSHServer struct {
commands *radix.Tree
listener net.Listener
// Call the cancel() function to stop all active sessions
ctx context.Context
cancel func()
// ctx parents per-Run contexts. Cancelling it (e.g. via Control.Stop) tears the server down even
// across reloads, since each Run derives a fresh child rather than reusing this one directly.
ctx context.Context
}
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
ctx, cancel := context.WithCancel(context.Background())
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen.
// The ssh server's context is parented off the supplied ctx so cancelling it
// (e.g. on Control.Stop) tears down active sessions and closes the listener.
func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) {
s := &SSHServer{
trustedKeys: make(map[string]map[string]bool),
l: l,
commands: radix.New(),
ctx: ctx,
cancel: cancel,
}
cc := ssh.CertChecker{
@@ -151,28 +150,51 @@ func (s *SSHServer) RegisterCommand(c *Command) {
s.commands.Insert(c.Name, c)
}
// Run begins listening and accepting connections
// Run begins listening and accepting connections. Each invocation derives a fresh per-Run context
// from the constructor-supplied ctx so a Stop+Run sequence (used by config reload) starts clean
// rather than carrying a permanently-cancelled context across runs.
func (s *SSHServer) Run(addr string) error {
var err error
s.listener, err = net.Listen("tcp", addr)
if s.ctx.Err() != nil {
return s.ctx.Err()
}
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
// s.listener is the public handle Stop uses to interrupt the active run; listener (the local) is what
// this run owns. They start equal but a fast reload may overwrite s.listener with the next run's
// listener before this run's watcher fires, so each run must close its own listener via the local
// reference.
s.listener = listener
runCtx, cancel := context.WithCancel(s.ctx)
defer cancel()
// Close the listener when this run's context is cancelled. That can come from the parent
// (Control.Stop), from Run returning normally (defer cancel above), or transitively when a sibling
// run cancels through Stop closing the listener. net.Listener.Close is idempotent so a duplicate
// close from Stop is benign.
go func() {
<-runCtx.Done()
if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
s.l.Warn("Failed to close the sshd listener", "error", err)
}
}()
s.l.Info("SSH server is listening", "sshListener", addr)
// Run loops until there is an error
s.run()
s.closeSessions()
s.run(runCtx, listener)
s.l.Info("SSH server stopped listening")
// We don't return an error because run logs for us
return nil
}
func (s *SSHServer) run() {
func (s *SSHServer) run(ctx context.Context, listener net.Listener) {
for {
c, err := s.listener.Accept()
c, err := listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
s.l.Warn("Error in listener, shutting down", "error", err)
@@ -184,7 +206,7 @@ func (s *SSHServer) run() {
// Ensure that a bad client doesn't hurt us by checking for the parent context
// cancellation before calling NewServerConn, and forcing the socket to close when
// the context is cancelled.
sessionContext, sessionCancel := context.WithCancel(s.ctx)
sessionContext, sessionCancel := context.WithCancel(ctx)
go func() {
<-sessionContext.Done()
c.Close()
@@ -227,14 +249,9 @@ func (s *SSHServer) run() {
}
func (s *SSHServer) Stop() {
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil {
if err := s.listener.Close(); err != nil {
s.l.Warn("Failed to close the sshd listener", "error", err)
}
}
}
func (s *SSHServer) closeSessions() {
s.cancel()
}
+17
View File
@@ -8,6 +8,23 @@ import (
// How many timer objects should be cached
const timerCacheMax = 50000
// TimerWheel is a hashed timing wheel: a fixed slot array indexed by (now + delay) % wheelLen,
// with each slot a singly linked list of items due in that bucket.
// Adds are O(1), Purges return items in arrival-within-slot order, and an internal cache of TimeoutItems
// keeps steady-state inserts allocation-free.
//
// The TimerWheel does not handle concurrency or lifecycle on its own.
// Callers drive Advance/Purge from their own ticker loop, take their own locks (or use LockingTimerWheel),
// and decide whether to keep ticking when the wheel is empty.
//
// Pick a TimerWheel when scheduling is high-rate and uniform: line-rate conntrack inserts,
// per-tunnel traffic checks at fixed intervals. O(1) insert plus the item cache means the hot path doesn't allocate.
// Items added in the same tick are dispatched together when that slot rotates current,
// which amortizes the cost of waking the worker.
//
// Pick a Scheduler when delay precision matters or scheduling is sparse or uneven.
// The wheel rounds requested timeouts up to its tick resolution and clamps anything beyond its wheel duration;
// both are silent in this implementation.
type TimerWheel[T any] struct {
// Current tick
current int
+1 -2
View File
@@ -5,12 +5,11 @@ package udp
import (
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"log/slog"
"golang.org/x/sys/unix"
)
+1 -2
View File
@@ -8,12 +8,11 @@ package udp
import (
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"log/slog"
"golang.org/x/sys/unix"
)
+57
View File
@@ -0,0 +1,57 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
package udp
import (
"log/slog"
"sync"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/wfp"
)
// wrapWithWDFBypass wraps a Conn so that the first ReloadConfig consults listen.windows_bypass_wdf
// and installs a WFP PERMIT filter for the listener's bound UDP port. The session is released when Close runs.
func wrapWithWDFBypass(l *slog.Logger, conn Conn) Conn {
return &bypassConn{Conn: conn, l: l}
}
type bypassConn struct {
Conn
l *slog.Logger
installOnce sync.Once
session *wfp.Session
}
func (b *bypassConn) ReloadConfig(c *config.C) {
b.installOnce.Do(func() {
if !c.GetBool("listen.windows_bypass_wdf", true) {
return
}
addr, err := b.Conn.LocalAddr()
if err != nil {
b.l.Warn("Failed to query listener port for WFP bypass", "error", err)
return
}
s, err := wfp.PermitUDPPort(addr.Port())
if err != nil {
b.l.Warn("Failed to install WFP bypass filters for listener", "error", err)
return
}
b.l.Info("Installed WFP filters bypassing Windows Defender Firewall on UDP listener port",
"port", addr.Port())
b.session = s
})
b.Conn.ReloadConfig(c)
}
func (b *bypassConn) Close() error {
if b.session != nil {
b.session.Close()
b.session = nil
}
return b.Conn.Close()
}
+11
View File
@@ -0,0 +1,11 @@
//go:build !e2e_testing
// +build !e2e_testing
package udp
import "log/slog"
// wrapWithWDFBypass is a no-op on windows-386 since we don't currently build for it.
func wrapWithWDFBypass(_ *slog.Logger, conn Conn) Conn {
return conn
}
+1 -2
View File
@@ -7,12 +7,11 @@ package udp
import (
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"log/slog"
"golang.org/x/sys/unix"
)
+9 -4
View File
@@ -19,13 +19,18 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int)
return nil, fmt.Errorf("multiple udp listeners not supported on windows")
}
var conn Conn
rc, err := NewRIOListener(l, ip, port)
if err == nil {
return rc, nil
conn = rc
} else {
l.Error("Falling back to standard udp sockets", "error", err)
conn, err = NewGenericListener(l, ip, port, multi, batch)
if err != nil {
return nil, err
}
}
l.Error("Falling back to standard udp sockets", "error", err)
return NewGenericListener(l, ip, port, multi, batch)
return wrapWithWDFBypass(l, conn), nil
}
func NewListenConfig(multi bool) net.ListenConfig {
+377
View File
@@ -0,0 +1,377 @@
//go:build (amd64 || arm64) && !e2e_testing
// +build amd64 arm64
// +build !e2e_testing
// Package wfp installs Windows Filtering Platform (WFP) PERMIT filters in a dynamic, session-scoped sublayer.
// Because WFP sits below Windows Defender Firewall, a high-weight permit at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4/V6 lets
// the matching inbound traffic through regardless of WDF rules.
//
// Each Session owns its own engine handle. When the handle closes, every dynamic object added during the session
// is auto-deleted by Windows, so there are no orphaned filters.
//
// Type definitions and constants are derived from the wireguard-windows firewall package (MIT).
// Only the subset we exercise is reproduced.
package wfp
import (
"fmt"
"unsafe"
"golang.org/x/sys/windows"
)
// FWPM layer GUIDs (fwpmu.h).
//
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = e1cd9fe7-f4b5-4273-96c0-592e487b8650
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = a3b42c97-9f04-4672-b87e-cee9c483257f
var (
fwpmLayerAleAuthRecvAcceptV4 = windows.GUID{
Data1: 0xe1cd9fe7, Data2: 0xf4b5, Data3: 0x4273,
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
}
fwpmLayerAleAuthRecvAcceptV6 = windows.GUID{
Data1: 0xa3b42c97, Data2: 0x9f04, Data3: 0x4672,
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
}
)
// FWPM_CONDITION_IP_LOCAL_INTERFACE = 4cd62a49-59c3-4969-b7f3-bda5d32890a4
var fwpmConditionIPLocalInterface = windows.GUID{
Data1: 0x4cd62a49, Data2: 0x59c3, Data3: 0x4969,
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
}
// FWPM_CONDITION_IP_PROTOCOL = 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
var fwpmConditionIPProtocol = windows.GUID{
Data1: 0x3971ef2b, Data2: 0x623e, Data3: 0x4f9a,
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
}
// FWPM_CONDITION_IP_LOCAL_PORT = 0c1ba1af-5765-453f-af22-a8f791ac775b
var fwpmConditionIPLocalPort = windows.GUID{
Data1: 0x0c1ba1af, Data2: 0x5765, Data3: 0x453f,
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
}
// IPPROTO_UDP from in.h.
const ipprotoUDP uint8 = 17
// FWP_ACTION_TYPE values (fwptypes.h). PERMIT is terminating.
const fwpActionPermit uint32 = 0x00001002 // 0x2 | FWP_ACTION_FLAG_TERMINATING(0x1000)
// FWP_DATA_TYPE values we use.
const (
fwpEmpty uint32 = 0
fwpUint8 uint32 = 1
fwpUint16 uint32 = 2
fwpUint64 uint32 = 4
)
// FWP_MATCH_TYPE values.
const fwpMatchEqual uint32 = 0
// FWPM_SESSION flags.
const fwpmSessionFlagDynamic uint32 = 0x1
// FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT prevents lower-priority filters in other sublayers,
// notably Windows Defender Firewall's MPSSVC_WF sublayer, which shares our 0xFFFF weight from overriding this PERMIT.
// Without it, a default WDF block at the same sublayer weight can still win arbitration.
const fwpmFilterFlagClearActionRight uint32 = 0x8
// RPC authentication.
// RPC_C_AUTHN_WINNT works on workgroup machines with no domain context
// RPC_C_AUTHN_DEFAULT falls back through a chain that can land on something WFP doesn't accept on a fresh box.
const rpcCAuthnWinNT uint32 = 10
// fwpByteBlob (FWP_BYTE_BLOB). 16 bytes on 64-bit.
type fwpByteBlob struct {
size uint32
_ uint32 // padding
data *uint8
}
// fwpValue0 / FWP_CONDITION_VALUE0 layout. 16 bytes on 64-bit.
// The union is pointer-sized; types <= 32 bits (UINT8/16/32, INT8/16/32, float) live inline in the low bytes
// of `value`, while UINT64/INT64/double and aggregate types are stored *by pointer*, even on 64-bit, where the
// union member is declared as UINT64*. So when populating an FWP_UINT64 condition, pass
// uintptr(unsafe.Pointer(&luidVar)) instead of the LUID inline.
type fwpValue0 struct {
type_ uint32
_ uint32 // padding before union to 8-byte alignment
value uintptr
}
// fwpmDisplayData0 / FWPM_DISPLAY_DATA0. 16 bytes on 64-bit.
type fwpmDisplayData0 struct {
name *uint16
description *uint16
}
// fwpmAction0 / FWPM_ACTION0. 20 bytes; no leading padding because actionType
// is uint32 and GUID's first field is uint32.
type fwpmAction0 struct {
actionType uint32
filterType windows.GUID
}
// fwpmFilterCondition0. 40 bytes on 64-bit.
type fwpmFilterCondition0 struct {
fieldKey windows.GUID // 16
matchType uint32 // 4
_ uint32 // 4 padding
conditionValue fwpValue0 // 16
}
// fwpmFilter0. 200 bytes on 64-bit.
type fwpmFilter0 struct {
filterKey windows.GUID
displayData fwpmDisplayData0
flags uint32
_ uint32 // padding before *GUID
providerKey *windows.GUID
providerData fwpByteBlob
layerKey windows.GUID
subLayerKey windows.GUID
weight fwpValue0
numFilterConditions uint32
_ uint32 // padding before pointer
filterCondition *fwpmFilterCondition0
action fwpmAction0
_ [4]byte // layout correction
providerContextKey windows.GUID
reserved *windows.GUID
filterID uint64
effectiveWeight fwpValue0
}
// fwpmSublayer0. 72 bytes on 64-bit.
type fwpmSublayer0 struct {
subLayerKey windows.GUID
displayData fwpmDisplayData0
flags uint32
_ uint32 // padding before *GUID
providerKey *windows.GUID
providerData fwpByteBlob
weight uint16
_ [6]byte // padding to 72 bytes
}
// fwpmSession0. 72 bytes on 64-bit.
type fwpmSession0 struct {
sessionKey windows.GUID
displayData fwpmDisplayData0
flags uint32
txnWaitTimeoutInMSec uint32
processId uint32
_ uint32 // padding before *SID
sid *windows.SID
username *uint16
kernelMode uint8
_ [7]byte // tail padding
}
// fwpuclnt.dll bindings. Only the calls we use.
var (
modFwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
procFwpmEngineOpen0 = modFwpuclnt.NewProc("FwpmEngineOpen0")
procFwpmEngineClose0 = modFwpuclnt.NewProc("FwpmEngineClose0")
procFwpmSubLayerAdd0 = modFwpuclnt.NewProc("FwpmSubLayerAdd0")
procFwpmFilterAdd0 = modFwpuclnt.NewProc("FwpmFilterAdd0")
)
// Session holds the WFP engine handle for a single bypass operation. The handle owns a dynamic session:
// when it is closed, every WFP object added during the session (sublayer + filters) is automatically deleted by
// Windows. That gives us correct cleanup even if the host process is killed hard between Permit* and Close.
type Session struct {
engine uintptr
}
// Close releases the engine handle. Windows deletes every dynamic object (sublayer + filters) the session installed.
// Safe to call on a nil receiver.
func (s *Session) Close() {
if s == nil || s.engine == 0 {
return
}
procFwpmEngineClose0.Call(s.engine)
s.engine = 0
}
// PermitInterface installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to the given network
// interface LUID. Inbound traffic on that interface bypasses Windows Defender Firewall.
func PermitInterface(luid uint64) (*Session, error) {
s, sublayerKey, err := newSession()
if err != nil {
return nil, err
}
if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, luid); err != nil {
s.Close()
return nil, fmt.Errorf("add v4 filter: %w", err)
}
if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, luid); err != nil {
s.Close()
return nil, fmt.Errorf("add v6 filter: %w", err)
}
return s, nil
}
// PermitUDPPort installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to UDP traffic with the
// given local port. Inbound UDP to that port on any interface bypasses Windows Defender Firewall.
func PermitUDPPort(port uint16) (*Session, error) {
s, sublayerKey, err := newSession()
if err != nil {
return nil, err
}
if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, port); err != nil {
s.Close()
return nil, fmt.Errorf("add v4 filter: %w", err)
}
if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, port); err != nil {
s.Close()
return nil, fmt.Errorf("add v6 filter: %w", err)
}
return s, nil
}
func newSession() (*Session, windows.GUID, error) {
engine, err := openDynamicEngine()
if err != nil {
return nil, windows.GUID{}, err
}
sublayerKey, err := registerSublayer(engine)
if err != nil {
procFwpmEngineClose0.Call(engine)
return nil, windows.GUID{}, err
}
return &Session{engine: engine}, sublayerKey, nil
}
func openDynamicEngine() (uintptr, error) {
session := fwpmSession0{flags: fwpmSessionFlagDynamic}
var engine uintptr
r1, _, _ := procFwpmEngineOpen0.Call(
0, // serverName == NULL (local)
uintptr(rpcCAuthnWinNT),
0, // authIdentity == NULL
uintptr(unsafe.Pointer(&session)),
uintptr(unsafe.Pointer(&engine)),
)
if r1 != 0 {
return 0, fmt.Errorf("FwpmEngineOpen0: 0x%x", r1)
}
return engine, nil
}
// registerSublayer adds a session-scoped sublayer with a freshly generated GUID, weight 0xFFFF so its filters arbitrate
// above WDF's default sublayer. The sublayer is dynamic (no PERSISTENT flag) and goes away when the engine handle closes.
func registerSublayer(engine uintptr) (windows.GUID, error) {
key, err := windows.GenerateGUID()
if err != nil {
return windows.GUID{}, fmt.Errorf("GenerateGUID for sublayer: %w", err)
}
name, _ := windows.UTF16PtrFromString("Nebula WDF bypass sublayer")
desc, _ := windows.UTF16PtrFromString("Permit filters bypassing Windows Defender Firewall")
sl := fwpmSublayer0{
subLayerKey: key,
displayData: fwpmDisplayData0{name: name, description: desc},
weight: 0xFFFF,
}
r1, _, _ := procFwpmSubLayerAdd0.Call(
engine,
uintptr(unsafe.Pointer(&sl)),
0, // sd == NULL
)
if r1 != 0 {
return windows.GUID{}, fmt.Errorf("FwpmSubLayerAdd0: 0x%x", r1)
}
return key, nil
}
func addInterfaceFilter(engine uintptr, sublayerKey, layer windows.GUID, luid uint64) error {
name, _ := windows.UTF16PtrFromString("Nebula allow interface inbound")
desc, _ := windows.UTF16PtrFromString("Permits inbound traffic on a nebula interface")
// luid must remain addressable through the syscall -- FWP_UINT64 is stored
// by pointer in the FWP_VALUE0 union.
cond := fwpmFilterCondition0{
fieldKey: fwpmConditionIPLocalInterface,
matchType: fwpMatchEqual,
conditionValue: fwpValue0{
type_: fwpUint64,
value: uintptr(unsafe.Pointer(&luid)),
},
}
filter := fwpmFilter0{
// filterKey left zero: WFP assigns one when the filter is added.
displayData: fwpmDisplayData0{name: name, description: desc},
flags: fwpmFilterFlagClearActionRight,
layerKey: layer,
subLayerKey: sublayerKey,
weight: fwpValue0{type_: fwpUint8, value: uintptr(15)},
numFilterConditions: 1,
filterCondition: &cond,
action: fwpmAction0{actionType: fwpActionPermit},
}
r1, _, _ := procFwpmFilterAdd0.Call(
engine,
uintptr(unsafe.Pointer(&filter)),
0, // sd == NULL
0, // id == NULL
)
if r1 != 0 {
return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1)
}
return nil
}
// addUDPPortFilter installs a PERMIT filter that matches (IP_PROTOCOL == UDP) AND (IP_LOCAL_PORT == port).
// FWP_UINT8 and FWP_UINT16 are <= 32 bits so they live inline in the FWP_VALUE0 union.
func addUDPPortFilter(engine uintptr, sublayerKey, layer windows.GUID, port uint16) error {
name, _ := windows.UTF16PtrFromString("Nebula allow UDP port inbound")
desc, _ := windows.UTF16PtrFromString("Permits inbound UDP to a nebula listener port")
conds := [2]fwpmFilterCondition0{
{
fieldKey: fwpmConditionIPProtocol,
matchType: fwpMatchEqual,
conditionValue: fwpValue0{
type_: fwpUint8,
value: uintptr(ipprotoUDP),
},
},
{
fieldKey: fwpmConditionIPLocalPort,
matchType: fwpMatchEqual,
conditionValue: fwpValue0{
type_: fwpUint16,
value: uintptr(port),
},
},
}
filter := fwpmFilter0{
displayData: fwpmDisplayData0{name: name, description: desc},
flags: fwpmFilterFlagClearActionRight,
layerKey: layer,
subLayerKey: sublayerKey,
weight: fwpValue0{type_: fwpUint8, value: uintptr(15)},
numFilterConditions: 2,
filterCondition: &conds[0],
action: fwpmAction0{actionType: fwpActionPermit},
}
r1, _, _ := procFwpmFilterAdd0.Call(
engine,
uintptr(unsafe.Pointer(&filter)),
0,
0,
)
if r1 != 0 {
return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1)
}
return nil
}