mirror of
https://github.com/slackhq/nebula.git
synced 2026-06-30 18:40:29 +02:00
Merge remote-tracking branch 'origin/master' into multiport
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
name: Code-sign Windows binaries
|
||||
description: >
|
||||
Sign every .exe under a given path in place via the DefinedNet code-signer
|
||||
Lambda. If `role` or `bucket` is empty, logs a notice and skips signing so
|
||||
forks and dev branches without AWS access still produce usable builds.
|
||||
|
||||
inputs:
|
||||
path:
|
||||
description: "Directory whose .exe files should be signed in place"
|
||||
required: true
|
||||
role:
|
||||
description: "IAM role ARN to assume via OIDC; empty disables signing"
|
||||
required: false
|
||||
default: ""
|
||||
bucket:
|
||||
description: "S3 staging bucket the code-signer Lambda reads from; empty disables signing"
|
||||
required: false
|
||||
default: ""
|
||||
region:
|
||||
description: "AWS region for the role and Lambda"
|
||||
required: false
|
||||
default: "us-east-2"
|
||||
function-name:
|
||||
description: "Code-signer Lambda function name"
|
||||
required: false
|
||||
default: "code-signer"
|
||||
key-prefix:
|
||||
description: "S3 key prefix the caller is authorized to write under"
|
||||
required: false
|
||||
default: "code-signing/slackhq/nebula"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Skip notice
|
||||
if: inputs.role == '' || inputs.bucket == ''
|
||||
shell: sh
|
||||
run: echo "::notice::code-signer role or bucket not set; skipping code signing."
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: inputs.role != '' && inputs.bucket != ''
|
||||
uses: aws-actions/configure-aws-credentials@v6
|
||||
with:
|
||||
role-to-assume: ${{ inputs.role }}
|
||||
aws-region: ${{ inputs.region }}
|
||||
# Default is 12 retries to ride out IAM trust-policy propagation; once
|
||||
# the role is stable we want a real misconfiguration to fail fast.
|
||||
retry-max-attempts: 5
|
||||
|
||||
- name: Sign .exe files
|
||||
if: inputs.role != '' && inputs.bucket != ''
|
||||
shell: sh
|
||||
env:
|
||||
SIGN_PATH: ${{ inputs.path }}
|
||||
BUCKET: ${{ inputs.bucket }}
|
||||
FUNCTION_NAME: ${{ inputs.function-name }}
|
||||
KEY_PREFIX: ${{ inputs.key-prefix }}
|
||||
run: |
|
||||
# Sign every .exe found under $SIGN_PATH, in place.
#
# For each file: upload it to the S3 staging bucket, invoke the code-signer
# Lambda on it, download the signed object over the original, delete both
# staged objects, then validate the Authenticode signature locally.
#
# Reads from the environment: SIGN_PATH, BUCKET, FUNCTION_NAME, KEY_PREFIX,
# GITHUB_RUN_ID, GITHUB_RUN_ATTEMPT. Exits 1 on the first file that fails to
# sign or validate.
set -eu

# Run id + attempt form the per-invocation component of every S3 key.
RUN="${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"

# IFS= stops `read` from trimming leading/trailing whitespace off a path;
# -r keeps backslashes in the path literal.
find "$SIGN_PATH" -name '*.exe' -print | while IFS= read -r path
do
  # Path relative to $SIGN_PATH, reused as the tail of both S3 keys.
  rel=${path#"$SIGN_PATH"/}
  file=$(basename "$path")
  # File name without the .exe suffix; sent to the Lambda as program_name.
  name=${file%.exe}
  prefix="${KEY_PREFIX}/${RUN}"
  src="${prefix}/unsigned/${rel}"
  dst="${prefix}/signed/${rel}"

  echo "::group::Sign ${rel}"
  echo "Uploading unsigned to s3://${BUCKET}/${src}"
  aws s3 cp --no-progress "$path" "s3://${BUCKET}/${src}" >/dev/null

  echo "Invoking ${FUNCTION_NAME} Lambda"
  # Build the request with jq so the values are JSON-escaped correctly.
  payload=$(jq -nc \
    --arg s "$src" \
    --arg d "$dst" \
    --arg p "$name" \
    '{source_key: $s, dest_key: $d, program_name: $p}')
  # The CLI prints invocation metadata on stdout and writes the function's
  # response body to the file given as the last argument.
  meta=$(aws lambda invoke \
    --function-name "$FUNCTION_NAME" \
    --cli-binary-format raw-in-base64-out \
    --payload "$payload" \
    --output json \
    /tmp/sign-resp.json)
  # A function-level failure is reported through the FunctionError field of
  # the metadata, so it has to be checked explicitly.
  if echo "$meta" | jq -e '.FunctionError != null' >/dev/null
  then
    echo "::endgroup::"
    echo "::error::code-signer Lambda failed for ${rel}"
    cat /tmp/sign-resp.json >&2
    exit 1
  fi

  echo "Downloading signed back to ${path}"
  aws s3 cp --no-progress "s3://${BUCKET}/${dst}" "$path" >/dev/null

  # Best-effort cleanup of the staged objects; a failed delete must not fail
  # the build.
  aws s3 rm "s3://${BUCKET}/${src}" >/dev/null 2>&1 || true
  aws s3 rm "s3://${BUCKET}/${dst}" >/dev/null 2>&1 || true

  # Sanity-check the bytes we got back actually carry an Authenticode
  # signature that this machine can validate end to end.
  status=$(powershell -NoProfile -Command "(Get-AuthenticodeSignature -FilePath '$path').Status" | tr -d '\r')
  if [ "$status" != "Valid" ]
  then
    echo "::endgroup::"
    echo "::error::${rel} signature status: ${status} (expected Valid)"
    exit 1
  fi

  echo "Signed ${rel} (sha256=$(jq -r '.sha256' /tmp/sign-resp.json), status=${status})"
  echo "::endgroup::"
done
|
||||
@@ -1,34 +0,0 @@
|
||||
name: gofmt
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
paths:
|
||||
- '.github/workflows/gofmt.yml'
|
||||
- '**.go'
|
||||
jobs:
|
||||
|
||||
gofmt:
|
||||
name: Run gofmt
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Install goimports
|
||||
run: |
|
||||
go install golang.org/x/tools/cmd/goimports@latest
|
||||
|
||||
- name: gofmt
|
||||
run: |
|
||||
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
|
||||
then
|
||||
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
|
||||
exit 1
|
||||
fi
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
mv build/*.tar.gz release
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: linux-latest
|
||||
path: release
|
||||
@@ -32,6 +32,9 @@ jobs:
|
||||
build-windows:
|
||||
name: Build Windows
|
||||
runs-on: windows-latest
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
@@ -54,8 +57,15 @@ jobs:
|
||||
mkdir build\dist\windows
|
||||
mv dist\windows\wintun build\dist\windows\
|
||||
|
||||
- name: Code-sign
|
||||
uses: ./.github/actions/code-sign
|
||||
with:
|
||||
path: build
|
||||
role: ${{ secrets.DEFINED_CODE_SIGNER_ROLE }}
|
||||
bucket: ${{ secrets.DEFINED_CODE_SIGNER_BUCKET }}
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: windows-latest
|
||||
path: build
|
||||
@@ -75,7 +85,7 @@ jobs:
|
||||
|
||||
- name: Import certificates
|
||||
if: env.HAS_SIGNING_CREDS == 'true'
|
||||
uses: Apple-Actions/import-codesign-certs@v6
|
||||
uses: Apple-Actions/import-codesign-certs@v7
|
||||
with:
|
||||
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
|
||||
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
|
||||
@@ -104,7 +114,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: darwin-latest
|
||||
path: ./release/*
|
||||
@@ -128,21 +138,21 @@ jobs:
|
||||
|
||||
- name: Download artifacts
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
uses: actions/download-artifact@v7
|
||||
uses: actions/download-artifact@v8
|
||||
with:
|
||||
name: linux-latest
|
||||
path: artifacts
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
username: ${{ vars.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v4
|
||||
|
||||
- name: Build and push images
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
@@ -163,7 +173,7 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v7
|
||||
uses: actions/download-artifact@v8
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
|
||||
@@ -14,10 +14,18 @@ on:
|
||||
- 'go.sum'
|
||||
jobs:
|
||||
|
||||
smoke-extra:
|
||||
smoke-extra-libvirt:
|
||||
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
||||
name: Run extra smoke tests
|
||||
name: ${{ matrix.target }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
target:
|
||||
- freebsd-amd64
|
||||
- openbsd-amd64
|
||||
- netbsd-amd64
|
||||
- linux-amd64-ipv6disable
|
||||
env:
|
||||
VAGRANT_DEFAULT_PROVIDER: libvirt
|
||||
steps:
|
||||
@@ -40,28 +48,85 @@ jobs:
|
||||
sudo chmod 666 /var/run/libvirt/libvirt-sock
|
||||
vagrant plugin install vagrant-libvirt
|
||||
|
||||
- name: freebsd-amd64
|
||||
run: make smoke-vagrant/freebsd-amd64
|
||||
- name: ${{ matrix.target }}
|
||||
run: make smoke-vagrant/${{ matrix.target }}
|
||||
|
||||
- name: openbsd-amd64
|
||||
run: make smoke-vagrant/openbsd-amd64
|
||||
timeout-minutes: 30
|
||||
|
||||
- name: netbsd-amd64
|
||||
run: make smoke-vagrant/netbsd-amd64
|
||||
# linux-386 needs VirtualBox, which conflicts with KVM/libvirt -- isolated job.
|
||||
smoke-extra-virtualbox:
|
||||
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
||||
name: linux-386
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
VAGRANT_DEFAULT_PROVIDER: virtualbox
|
||||
steps:
|
||||
|
||||
- name: linux-amd64-ipv6disable
|
||||
run: make smoke-vagrant/linux-amd64-ipv6disable
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
# linux-386 runs last because it requires disabling KVM to use VirtualBox,
|
||||
# which prevents libvirt (used by the other tests) from working after this point.
|
||||
- name: install virtualbox for i386 test
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: add hashicorp source
|
||||
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
||||
|
||||
- name: install vagrant and virtualbox
|
||||
run: |
|
||||
sudo apt-get install -y virtualbox
|
||||
sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
||||
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
|
||||
|
||||
- name: linux-386
|
||||
env:
|
||||
VAGRANT_DEFAULT_PROVIDER: virtualbox
|
||||
run: make smoke-vagrant/linux-386
|
||||
|
||||
timeout-minutes: 30
|
||||
|
||||
smoke-windows:
|
||||
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
||||
name: Run windows smoke test
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
# WSL2 + Ubuntu so the smoke can run a real linux peer with its own
|
||||
# netns. iputils-ping is needed for the in-WSL ping check. WSL1 has no
|
||||
# real kernel and would lack /dev/net/tun, so we have to force WSL2.
|
||||
- uses: Vampire/setup-wsl@v3
|
||||
with:
|
||||
distribution: Ubuntu-24.04
|
||||
additional-packages: iputils-ping iproute2
|
||||
|
||||
# Vampire/setup-wsl provisions WSL1 even when the WSL2 platform is present.
|
||||
# Convert the distro to WSL2 explicitly before we try to use /dev/net/tun.
|
||||
- name: convert distro to WSL2
|
||||
shell: pwsh
|
||||
run: |
|
||||
wsl --set-version Ubuntu-24.04 2
|
||||
wsl --shutdown
|
||||
wsl --list --verbose
|
||||
|
||||
- name: build windows nebula
|
||||
run: make bin-windows
|
||||
|
||||
- name: build linux nebula for WSL
|
||||
shell: bash
|
||||
env:
|
||||
GOOS: linux
|
||||
GOARCH: amd64
|
||||
run: |
|
||||
mkdir -p build/linux-amd64
|
||||
go build -o build/linux-amd64/nebula ./cmd/nebula
|
||||
|
||||
- name: run smoke-windows
|
||||
shell: pwsh
|
||||
working-directory: ./.github/workflows/smoke
|
||||
run: ./smoke-windows.ps1
|
||||
|
||||
timeout-minutes: 15
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
#!/usr/bin/env pwsh
|
||||
# Windows smoke test for the nebula tun + UDP + NLM code paths.
|
||||
#
|
||||
# Topology:
|
||||
# - lighthouse runs natively on the Windows host (wintun + windows UDP)
|
||||
# - peer runs inside WSL2 (Linux build of nebula, /dev/net/tun)
|
||||
#
|
||||
# WSL2 gives us a real netns boundary so the loopback fast-path on Windows
|
||||
# does not short-circuit the overlay -- when WSL pings the lighthouse VPN IP,
|
||||
# Linux has no idea that IP is local to the Windows host, so the packet is
|
||||
# forced through nebula. Same in reverse.
|
||||
|
||||
$ErrorActionPreference = 'Stop'
|
||||
|
||||
# wsl.exe emits UTF-16 LE by default which PowerShell reads as bytes, mangling
|
||||
# every captured string. WSL_UTF8 makes wsl.exe emit UTF-8 instead.
|
||||
$env:WSL_UTF8 = '1'
|
||||
|
||||
$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.."
|
||||
$Nebula = Join-Path $RepoRoot 'nebula.exe'
|
||||
$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe'
|
||||
$NebulaLinux = Join-Path $RepoRoot 'build\linux-amd64\nebula'
|
||||
|
||||
if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" }
|
||||
if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" }
|
||||
if (-not (Test-Path $NebulaLinux)) { throw "missing $NebulaLinux; build the linux nebula first" }
|
||||
|
||||
# Matches the distro installed by Vampire/setup-wsl in smoke-extra.yml.
|
||||
$Distro = 'Ubuntu-24.04'
|
||||
$listed = (wsl --list --quiet 2>$null) -join "`n"
|
||||
if ($listed -notmatch [regex]::Escape($Distro)) {
|
||||
throw "WSL distro $Distro not registered. Got: $listed"
|
||||
}
|
||||
Write-Host "Using WSL distro: $Distro"
|
||||
|
||||
# Windows host as seen from inside WSL: WSL's default-route gateway. We extract
|
||||
# it with a regex rather than awk fields so PowerShell does not eat any '$N'
|
||||
# tokens, and tabs/double-spaces in `ip route` output do not confuse a cut.
|
||||
$ipCmd = 'ip route show default | grep -oE "([0-9]+\.){3}[0-9]+" | head -1'
|
||||
$WindowsIp = (wsl -d $Distro -- bash -c $ipCmd).Trim()
|
||||
if (-not $WindowsIp) { throw "could not determine Windows host IP from WSL" }
|
||||
Write-Host "Windows host IP from WSL: $WindowsIp"
|
||||
|
||||
$WorkDir = Join-Path $env:TEMP 'nebula-smoke-windows'
|
||||
if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir }
|
||||
New-Item -ItemType Directory -Path $WorkDir | Out-Null
|
||||
|
||||
$WslDir = '/tmp/nebula-smoke'
|
||||
wsl -d $Distro -- bash -c "rm -rf $WslDir && mkdir -p $WslDir" | Out-Null
|
||||
|
||||
$DevName = 'nebula-smoke'
|
||||
$Ip1 = '192.168.241.1'
|
||||
$Ip2 = '192.168.241.2'
|
||||
$Port = 4242
|
||||
|
||||
& $NebulaCert ca -name 'smoke-ca' -out-crt "$WorkDir\ca.crt" -out-key "$WorkDir\ca.key"
|
||||
if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" }
|
||||
|
||||
& $NebulaCert sign -name 'lighthouse' -networks "$Ip1/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\lighthouse.crt" -out-key "$WorkDir\lighthouse.key"
|
||||
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign lighthouse failed (exit $LASTEXITCODE)" }
|
||||
|
||||
& $NebulaCert sign -name 'peer' -networks "$Ip2/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\peer.crt" -out-key "$WorkDir\peer.key"
|
||||
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign peer failed (exit $LASTEXITCODE)" }
|
||||
|
||||
# Windows lighthouse config.
|
||||
@"
|
||||
pki:
|
||||
ca: $WorkDir\ca.crt
|
||||
cert: $WorkDir\lighthouse.crt
|
||||
key: $WorkDir\lighthouse.key
|
||||
static_host_map: {}
|
||||
lighthouse:
|
||||
am_lighthouse: true
|
||||
interval: 60
|
||||
hosts: []
|
||||
listen:
|
||||
host: 0.0.0.0
|
||||
port: $Port
|
||||
tun:
|
||||
disabled: false
|
||||
dev: $DevName
|
||||
drop_local_broadcast: false
|
||||
drop_multicast: false
|
||||
tx_queue: 500
|
||||
mtu: 1300
|
||||
network_category: private
|
||||
logging:
|
||||
level: info
|
||||
format: text
|
||||
firewall:
|
||||
outbound_action: drop
|
||||
inbound_action: drop
|
||||
conntrack:
|
||||
tcp_timeout: 12m
|
||||
udp_timeout: 3m
|
||||
default_timeout: 10m
|
||||
outbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
inbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
"@ | Out-File -FilePath "$WorkDir\lighthouse.yml" -Encoding utf8
|
||||
|
||||
# WSL peer config (paths are POSIX, deliberately).
|
||||
@"
|
||||
pki:
|
||||
ca: $WslDir/ca.crt
|
||||
cert: $WslDir/peer.crt
|
||||
key: $WslDir/peer.key
|
||||
static_host_map:
|
||||
"${Ip1}": ["${WindowsIp}:$Port"]
|
||||
lighthouse:
|
||||
am_lighthouse: false
|
||||
interval: 60
|
||||
hosts:
|
||||
- "${Ip1}"
|
||||
listen:
|
||||
host: 0.0.0.0
|
||||
port: 0
|
||||
tun:
|
||||
disabled: false
|
||||
dev: nebula1
|
||||
drop_local_broadcast: false
|
||||
drop_multicast: false
|
||||
tx_queue: 500
|
||||
mtu: 1300
|
||||
logging:
|
||||
level: info
|
||||
format: text
|
||||
firewall:
|
||||
outbound_action: drop
|
||||
inbound_action: drop
|
||||
conntrack:
|
||||
tcp_timeout: 12m
|
||||
udp_timeout: 3m
|
||||
default_timeout: 10m
|
||||
outbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
inbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
"@ | Out-File -FilePath "$WorkDir\peer.yml" -Encoding utf8
|
||||
|
||||
# Stage WSL artifacts. Convert Windows paths to WSL paths ourselves rather than
|
||||
# calling `wslpath`, because PowerShell's argument-passing to external EXEs
|
||||
# strips backslashes from path arguments in ways that are hard to escape around.
|
||||
# Map a drive-rooted Windows path (C:\dir\file) to the matching WSL mount
# path (/mnt/c/dir/file). The drive letter is lower-cased and every backslash
# in the remainder becomes a forward slash. Throws when the input does not
# start with "<letter>:\".
function ConvertTo-WslPath {
    param([string]$WindowsPath)

    if ($WindowsPath -match '^([A-Za-z]):\\(.*)$') {
        $drive = $Matches[1].ToLower()
        $tail = $Matches[2].Replace('\', '/')
        return "/mnt/$drive/$tail"
    }

    throw "cannot convert path to WSL: $WindowsPath"
}
|
||||
|
||||
$WslWorkDir = ConvertTo-WslPath $WorkDir
|
||||
$WslNebulaPath = ConvertTo-WslPath $NebulaLinux
|
||||
wsl -d $Distro -- bash -c "cp '$WslWorkDir/ca.crt' '$WslWorkDir/peer.crt' '$WslWorkDir/peer.key' '$WslWorkDir/peer.yml' $WslDir/ && cp '$WslNebulaPath' $WslDir/nebula && chmod +x $WslDir/nebula"
|
||||
|
||||
# Make sure WSL has tun support and /dev/net/tun is usable before starting
|
||||
# nebula. Diagnostics first so a fail here points at the real problem (e.g.
|
||||
# WSL1 distros do not have a real kernel and will not have tun).
|
||||
Write-Host '=== WSL diagnostic ==='
|
||||
wsl --version 2>&1 | Out-Host
|
||||
wsl --list --verbose 2>&1 | Out-Host
|
||||
wsl -d $Distro -u root -- uname -a | Out-Host
|
||||
wsl -d $Distro -u root -- bash -c "modprobe tun 2>&1 || true; mkdir -p /dev/net; [ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; chmod 600 /dev/net/tun; ls -l /dev/net/tun"
|
||||
if ($LASTEXITCODE -ne 0) { throw "failed to prepare /dev/net/tun in WSL (TUN support missing?)" }
|
||||
|
||||
# Deliberately no New-NetFirewallRule calls here -- nebula's windows_bypass_wdf
|
||||
# feature is supposed to install WFP permit filters that let inbound traffic
|
||||
# through Windows Defender Firewall on its own. If this smoke regresses, that
|
||||
# feature regressed.
|
||||
|
||||
$lhOut = Join-Path $WorkDir 'lighthouse.out.log'
|
||||
$lhErr = Join-Path $WorkDir 'lighthouse.err.log'
|
||||
$lhProc = Start-Process -FilePath $Nebula -ArgumentList @('-config', "$WorkDir\lighthouse.yml") `
|
||||
-PassThru -NoNewWindow `
|
||||
-RedirectStandardOutput $lhOut `
|
||||
-RedirectStandardError $lhErr
|
||||
|
||||
# Run nebula in WSL as root with no sudo + no shell wrapper. PowerShell's
|
||||
# Start-Process arg quoting mangles `bash -c "..."` strings that contain
|
||||
# spaces/redirections, so we skip bash entirely and let Start-Process do the
|
||||
# stdout/stderr capture itself.
|
||||
$peerOut = Join-Path $WorkDir 'peer.out.log'
|
||||
$peerErr = Join-Path $WorkDir 'peer.err.log'
|
||||
$peerProc = Start-Process -FilePath 'wsl' `
|
||||
-ArgumentList @('-d', $Distro, '-u', 'root', '--', "$WslDir/nebula", '-config', "$WslDir/peer.yml") `
|
||||
-PassThru -NoNewWindow `
|
||||
-RedirectStandardOutput $peerOut `
|
||||
-RedirectStandardError $peerErr
|
||||
|
||||
# Poll $Predicate every 500 ms until it yields a truthy value or $TimeoutSec
# seconds have elapsed. Returns nothing on success; throws
# "timed out waiting for: $What" once the deadline has passed. Anything the
# predicate itself throws propagates to the caller.
function Wait-Until {
    param([scriptblock]$Predicate, [int]$TimeoutSec, [string]$What)

    $giveUpAt = (Get-Date).AddSeconds($TimeoutSec)
    while ((Get-Date) -lt $giveUpAt) {
        $satisfied = & $Predicate
        if ($satisfied) {
            return
        }
        Start-Sleep -Milliseconds 500
    }

    throw "timed out waiting for: $What"
}
|
||||
|
||||
try {
|
||||
Wait-Until -TimeoutSec 30 -What "windows wintun adapter $DevName with NetworkCategory=Private" -Predicate {
|
||||
if ($lhProc.HasExited) { throw "lighthouse exited (code $($lhProc.ExitCode)) before tun was ready" }
|
||||
$p = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue
|
||||
$p -and ("$($p.NetworkCategory)" -ieq 'Private')
|
||||
}
|
||||
Write-Host "OK: $DevName NetworkCategory=Private"
|
||||
|
||||
Wait-Until -TimeoutSec 30 -What "WSL nebula1 with $Ip2" -Predicate {
|
||||
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before tun was ready" }
|
||||
$r = wsl -d $Distro -u root -- bash -c "ip -o addr show nebula1 2>/dev/null | grep -q 'inet $Ip2' && echo yes"
|
||||
("$r").Trim() -eq 'yes'
|
||||
}
|
||||
Write-Host "OK: WSL nebula1 has $Ip2"
|
||||
|
||||
Wait-Until -TimeoutSec 30 -What "ping from WSL peer to windows lighthouse ($Ip1)" -Predicate {
|
||||
if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before ping succeeded" }
|
||||
$r = wsl -d $Distro -u root -- bash -c "ping -c1 -W1 $Ip1 >/dev/null 2>&1 && echo OK"
|
||||
("$r").Trim() -eq 'OK'
|
||||
}
|
||||
Write-Host "OK: WSL peer -> windows lighthouse"
|
||||
|
||||
Wait-Until -TimeoutSec 30 -What "ping from windows lighthouse to WSL peer ($Ip2)" -Predicate {
|
||||
$null = & ping.exe -n 1 -w 1000 $Ip2
|
||||
$LASTEXITCODE -eq 0
|
||||
}
|
||||
Write-Host "OK: windows lighthouse -> WSL peer"
|
||||
|
||||
Write-Host ''
|
||||
Write-Host 'All smoke checks passed.'
|
||||
}
|
||||
catch {
|
||||
Write-Host ''
|
||||
Write-Host '=== lighthouse stdout ==='
|
||||
Get-Content $lhOut -ErrorAction SilentlyContinue | Out-Host
|
||||
Write-Host '=== lighthouse stderr ==='
|
||||
Get-Content $lhErr -ErrorAction SilentlyContinue | Out-Host
|
||||
Write-Host '=== peer stdout ==='
|
||||
Get-Content $peerOut -ErrorAction SilentlyContinue | Out-Host
|
||||
Write-Host '=== peer stderr ==='
|
||||
Get-Content $peerErr -ErrorAction SilentlyContinue | Out-Host
|
||||
Write-Host '=== nebula WFP filters ==='
|
||||
# Dump nebula-installed filters so we can verify they got registered with
|
||||
# the conditions we expect.
|
||||
$wfpDump = Join-Path $WorkDir 'wfp.xml'
|
||||
netsh wfp show filters file=$wfpDump 2>&1 | Out-Null
|
||||
if (Test-Path $wfpDump) {
|
||||
Select-String -Path $wfpDump -Pattern 'Nebula' -Context 0,80 -ErrorAction SilentlyContinue | Out-Host
|
||||
}
|
||||
throw
|
||||
}
|
||||
finally {
|
||||
if (-not $lhProc.HasExited) {
|
||||
Stop-Process -Id $lhProc.Id -Force -ErrorAction SilentlyContinue
|
||||
$lhProc.WaitForExit(5000) | Out-Null
|
||||
}
|
||||
wsl -d $Distro -u root -- bash -c "pkill -f $WslDir/nebula 2>/dev/null; true" | Out-Null
|
||||
# pkill returns 1 when no match and wsl propagates that; the smoke is done
|
||||
# so we don't want it to leak into the script's exit code.
|
||||
$global:LASTEXITCODE = 0
|
||||
if ($peerProc -and -not $peerProc.HasExited) {
|
||||
Stop-Process -Id $peerProc.Id -Force -ErrorAction SilentlyContinue
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- mode: ruby -*-
|
||||
# vi: set ft=ruby :
|
||||
Vagrant.configure("2") do |config|
|
||||
config.vm.box = "generic/netbsd9"
|
||||
config.vm.box = "DefinedNet/netbsd10"
|
||||
|
||||
config.vm.synced_folder "../build", "/nebula", type: "rsync"
|
||||
end
|
||||
|
||||
+95
-77
@@ -13,8 +13,8 @@ on:
|
||||
- 'go.sum'
|
||||
jobs:
|
||||
|
||||
test-linux:
|
||||
name: Build all and test on ubuntu-linux
|
||||
static:
|
||||
name: Static checks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
@@ -25,8 +25,16 @@ jobs:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
run: make all
|
||||
- name: Install goimports
|
||||
run: go install golang.org/x/tools/cmd/goimports@latest
|
||||
|
||||
- name: gofmt
|
||||
run: |
|
||||
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
|
||||
then
|
||||
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Vet
|
||||
run: make vet
|
||||
@@ -36,66 +44,38 @@ jobs:
|
||||
with:
|
||||
version: v2.5
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
|
||||
- name: End 2 end
|
||||
run: make e2evv
|
||||
|
||||
- name: Build test mobile
|
||||
run: make build-test-mobile
|
||||
|
||||
- uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: e2e packet flow linux-latest
|
||||
path: e2e/mermaid/linux-latest
|
||||
if-no-files-found: warn
|
||||
|
||||
test-linux-boringcrypto:
|
||||
name: Build and test on linux with boringcrypto
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
run: make bin-boringcrypto
|
||||
|
||||
- name: Test
|
||||
run: make test-boringcrypto
|
||||
|
||||
- name: End 2 end
|
||||
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
|
||||
|
||||
test-linux-pkcs11:
|
||||
name: Build and test on linux with pkcs11
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
run: make bin-pkcs11
|
||||
|
||||
- name: Test
|
||||
run: make test-pkcs11
|
||||
|
||||
test:
|
||||
name: Build and test on ${{ matrix.os }}
|
||||
name: Test ${{ matrix.name }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest]
|
||||
include:
|
||||
- name: linux
|
||||
os: ubuntu-latest
|
||||
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
|
||||
test-cmd: make test
|
||||
e2e-cmd: make e2evv
|
||||
- name: linux-boringcrypto
|
||||
os: ubuntu-latest
|
||||
build-cmd: make bin-boringcrypto
|
||||
test-cmd: make test-boringcrypto
|
||||
e2e-cmd: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
|
||||
- name: linux-pkcs11
|
||||
os: ubuntu-latest
|
||||
build-cmd: make bin-pkcs11
|
||||
test-cmd: make test-pkcs11
|
||||
e2e-cmd: ''
|
||||
- name: macos
|
||||
os: macos-latest
|
||||
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
|
||||
test-cmd: make test
|
||||
e2e-cmd: make e2evv
|
||||
- name: windows
|
||||
os: windows-latest
|
||||
build-cmd: go build ./cmd/nebula ./cmd/nebula-cert
|
||||
test-cmd: make test
|
||||
e2e-cmd: make e2evv
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
@@ -105,28 +85,66 @@ jobs:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build nebula
|
||||
run: go build ./cmd/nebula
|
||||
- name: Build
|
||||
run: ${{ matrix.build-cmd }}
|
||||
|
||||
- name: Build nebula-cert
|
||||
run: go build ./cmd/nebula-cert
|
||||
|
||||
- name: Vet
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.5
|
||||
- name: Cross-build darwin-amd64
|
||||
if: matrix.name == 'macos'
|
||||
run: GOARCH=amd64 go build -o /tmp/nebula-amd64 ./cmd/nebula && GOARCH=amd64 go build -o /tmp/nebula-cert-amd64 ./cmd/nebula-cert
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
run: ${{ matrix.test-cmd }}
|
||||
|
||||
- name: End 2 end
|
||||
run: make e2evv
|
||||
if: matrix.e2e-cmd != ''
|
||||
run: ${{ matrix.e2e-cmd }}
|
||||
|
||||
- uses: actions/upload-artifact@v6
|
||||
- uses: actions/upload-artifact@v7
|
||||
if: matrix.e2e-cmd != '' && always()
|
||||
with:
|
||||
name: e2e packet flow ${{ matrix.os }}
|
||||
path: e2e/mermaid/${{ matrix.os }}
|
||||
name: e2e packet flow ${{ matrix.name }}
|
||||
path: e2e/mermaid/
|
||||
if-no-files-found: warn
|
||||
|
||||
cross-build:
|
||||
name: Cross-build ${{ matrix.name }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- {name: linux-arm, make-target: all-cross-linux-arm}
|
||||
- {name: linux-mips, make-target: all-cross-linux-mips}
|
||||
- {name: linux-other, make-target: all-cross-linux-other}
|
||||
- {name: freebsd, make-target: all-freebsd}
|
||||
- {name: openbsd, make-target: all-openbsd}
|
||||
- {name: netbsd, make-target: all-netbsd}
|
||||
- {name: windows, make-target: all-cross-windows}
|
||||
- {name: mobile, make-target: build-test-mobile}
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: Build ${{ matrix.name }}
|
||||
run: make -j"$(nproc)" ${{ matrix.make-target }}
|
||||
|
||||
finish:
|
||||
name: CI status
|
||||
if: always()
|
||||
needs: [static, test, cross-build]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- name: Fail if any upstream job failed
|
||||
if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled')
|
||||
run: |
|
||||
echo "upstream results: ${{ toJSON(needs) }}"
|
||||
exit 1
|
||||
|
||||
- name: All upstream jobs passed
|
||||
run: echo "ok"
|
||||
|
||||
@@ -60,6 +60,18 @@ ALL = $(ALL_LINUX) \
|
||||
windows-amd64 \
|
||||
windows-arm64
|
||||
|
||||
# Cross-build shards used by .github/workflows/test.yml — same as ALL_*
|
||||
# but with the arch that has a native CI runner removed, so the cross-build
|
||||
# job is not duplicating coverage the native test jobs already give.
|
||||
ALL_CROSS_LINUX = $(filter-out linux-amd64,$(ALL_LINUX))
|
||||
|
||||
# ALL_CROSS_LINUX further split into family sub-shards so each can run on
|
||||
# its own CI runner in parallel. Union of the three must equal
|
||||
# ALL_CROSS_LINUX; adding a new linux arch goes into the matching family.
|
||||
ALL_CROSS_LINUX_ARM = linux-arm-5 linux-arm-6 linux-arm-7 linux-arm64
|
||||
ALL_CROSS_LINUX_MIPS = linux-mips linux-mipsle linux-mips64 linux-mips64le linux-mips-softfloat
|
||||
ALL_CROSS_LINUX_OTHER = linux-386 linux-ppc64le linux-riscv64 linux-loong64
|
||||
|
||||
e2e:
|
||||
$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
|
||||
|
||||
@@ -82,6 +94,35 @@ DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert
|
||||
|
||||
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
|
||||
|
||||
all-linux: $(ALL_LINUX:%=build/%/nebula) $(ALL_LINUX:%=build/%/nebula-cert)
|
||||
|
||||
all-freebsd: $(ALL_FREEBSD:%=build/%/nebula) $(ALL_FREEBSD:%=build/%/nebula-cert)
|
||||
|
||||
all-openbsd: $(ALL_OPENBSD:%=build/%/nebula) $(ALL_OPENBSD:%=build/%/nebula-cert)
|
||||
|
||||
all-netbsd: $(ALL_NETBSD:%=build/%/nebula) $(ALL_NETBSD:%=build/%/nebula-cert)
|
||||
|
||||
all-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert build/darwin-arm64/nebula build/darwin-arm64/nebula-cert
|
||||
|
||||
all-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
|
||||
|
||||
# CI cross-build shards. darwin-arm64 is covered by the native macos-latest
|
||||
# job; windows-amd64 is covered by the native windows-latest job; both are
|
||||
# omitted here to avoid building them a second time. darwin-amd64 stays in
|
||||
# all-cross-darwin because intel mac is only a labeled/master-time native
|
||||
# job, so PRs still need cross-build coverage for it.
|
||||
all-cross-linux: $(ALL_CROSS_LINUX:%=build/%/nebula) $(ALL_CROSS_LINUX:%=build/%/nebula-cert)
|
||||
|
||||
all-cross-linux-arm: $(ALL_CROSS_LINUX_ARM:%=build/%/nebula) $(ALL_CROSS_LINUX_ARM:%=build/%/nebula-cert)
|
||||
|
||||
all-cross-linux-mips: $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula) $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula-cert)
|
||||
|
||||
all-cross-linux-other: $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula) $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula-cert)
|
||||
|
||||
all-cross-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
|
||||
|
||||
all-cross-windows: build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
|
||||
|
||||
docker: docker/linux-$(shell go env GOARCH)
|
||||
|
||||
release: $(ALL:%=build/nebula-%.tar.gz)
|
||||
@@ -240,5 +281,5 @@ smoke-vagrant/%: bin-docker build/%/nebula
|
||||
cd .github/workflows/smoke/ && ./smoke-vagrant.sh $*
|
||||
|
||||
.FORCE:
|
||||
.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
|
||||
.PHONY: all all-linux all-freebsd all-openbsd all-netbsd all-darwin all-windows all-cross-linux all-cross-linux-arm all-cross-linux-mips all-cross-linux-other all-cross-darwin all-cross-windows bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
|
||||
.DEFAULT_GOAL := bin
|
||||
|
||||
@@ -217,6 +217,10 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if signer.Certificate.Curve() != c.Curve() {
|
||||
return nil, ErrCurveMismatch
|
||||
}
|
||||
|
||||
if signer.Certificate.Expired(now) {
|
||||
return nil, ErrRootExpired
|
||||
}
|
||||
|
||||
@@ -654,3 +654,31 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCertificateV2_CurveMismatch checks that caPool.verify accepts a P256 leaf
// issued by a P256 CA, and returns an error once the leaf's curve field has
// been changed to CURVE25519.
func TestCertificateV2_CurveMismatch(t *testing.T) {
|
||||
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
|
||||
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
|
||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||
|
||||
caPem, err := ca.MarshalPEM()
|
||||
require.NoError(t, err)
|
||||
|
||||
caPool := NewCAPool()
|
||||
b, err := caPool.AddCAFromPEM(caPem)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, b)
|
||||
|
||||
// leaf network 10.0.0.1/24 lies inside the CA's 10.0.0.0/16 and uses the same curve as the CA, so the first verify below is expected to succeed
|
||||
cIp1 := mustParsePrefixUnmapped("10.0.0.1/24")
|
||||
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1}, nil, []string{"test"})
|
||||
|
||||
fp, _ := c.Fingerprint()
|
||||
_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
|
||||
require.NoError(t, err)
|
||||
// overwrite the curve on the already-issued leaf so it no longer matches the P256 CA, recompute the fingerprint, and expect verify to return an error
|
||||
c2 := c.(*certificateV2)
|
||||
c2.curve = Curve_CURVE25519
|
||||
fp, _ = c.Fingerprint()
|
||||
_, err = caPool.verify(c, time.Now(), fp, c.Issuer())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -112,6 +112,9 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
|
||||
}
|
||||
switch c.details.curve {
|
||||
case Curve_CURVE25519:
|
||||
if len(key) != ed25519.PublicKeySize {
|
||||
return false //avoids a panic internal to ed25519
|
||||
}
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
|
||||
@@ -151,6 +151,9 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
|
||||
|
||||
switch c.curve {
|
||||
case Curve_CURVE25519:
|
||||
if len(key) != ed25519.PublicKeySize {
|
||||
return false //avoids a panic internal to ed25519
|
||||
}
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
|
||||
@@ -22,6 +22,7 @@ var (
|
||||
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
||||
ErrUnknownVersion = errors.New("certificate version unrecognized")
|
||||
ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present")
|
||||
ErrCurveMismatch = errors.New("certificate curve does not match CA")
|
||||
|
||||
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
||||
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
||||
|
||||
+10
-4
@@ -13,6 +13,12 @@ import (
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
// testCertNow is the reference "now" used to derive default before/after times
|
||||
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
|
||||
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
|
||||
// signed with default times can never expire after its CA on a rounding race.
|
||||
var testCertNow = time.Now().Round(time.Second)
|
||||
|
||||
// NewTestCaCert will create a new ca certificate
|
||||
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
|
||||
var err error
|
||||
@@ -34,10 +40,10 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
|
||||
}
|
||||
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
before = testCertNow.Add(time.Second * -60)
|
||||
}
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
after = testCertNow.Add(time.Second * 60)
|
||||
}
|
||||
|
||||
t := &TBSCertificate{
|
||||
@@ -70,11 +76,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
|
||||
// Expiry times are defaulted if you do not pass them in
|
||||
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
before = testCertNow.Add(time.Second * -60)
|
||||
}
|
||||
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
after = testCertNow.Add(time.Second * 60)
|
||||
}
|
||||
|
||||
if len(networks) == 0 {
|
||||
|
||||
+10
-4
@@ -14,6 +14,12 @@ import (
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
// testCertNow is the reference "now" used to derive default before/after times
|
||||
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
|
||||
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
|
||||
// signed with default times can never expire after its CA on a rounding race.
|
||||
var testCertNow = time.Now().Round(time.Second)
|
||||
|
||||
// NewTestCaCert will create a new ca certificate
|
||||
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
|
||||
var err error
|
||||
@@ -35,10 +41,10 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
|
||||
}
|
||||
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
before = testCertNow.Add(time.Second * -60)
|
||||
}
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
after = testCertNow.Add(time.Second * 60)
|
||||
}
|
||||
|
||||
t := &cert.TBSCertificate{
|
||||
@@ -71,11 +77,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
|
||||
// Expiry times are defaulted if you do not pass them in
|
||||
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
|
||||
if before.IsZero() {
|
||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
before = testCertNow.Add(time.Second * -60)
|
||||
}
|
||||
|
||||
if after.IsZero() {
|
||||
after = time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
after = testCertNow.Add(time.Second * 60)
|
||||
}
|
||||
|
||||
var pub, priv []byte
|
||||
|
||||
+32
-7
@@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// out-key is meaningless under PKCS#11 because the private key never
|
||||
// leaves the HSM; reject it so we never silently accept or claim a
|
||||
// stdout slot for it.
|
||||
outKeySet := false
|
||||
cf.set.Visit(func(f *flag.Flag) {
|
||||
if f.Name == "out-key" {
|
||||
outKeySet = true
|
||||
}
|
||||
})
|
||||
if outKeySet {
|
||||
return newHelpErrorf("cannot set -out-key with -pkcs11")
|
||||
}
|
||||
}
|
||||
if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
|
||||
return err
|
||||
@@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
}
|
||||
}
|
||||
|
||||
var claims ioClaims
|
||||
if err := reserveOutputs(&claims,
|
||||
"out-key", *cf.outKeyPath,
|
||||
"out-crt", *cf.outCertPath,
|
||||
"out-qr", *cf.outQRPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var passphrase []byte
|
||||
if !isP11 && *cf.encryption {
|
||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||
if len(passphrase) == 0 {
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
errOut.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if err == ErrNoTerminal {
|
||||
@@ -261,14 +283,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
Curve: curve,
|
||||
}
|
||||
|
||||
if !isP11 {
|
||||
if !isP11 && !isStdio(*cf.outKeyPath) {
|
||||
if _, err := os.Stat(*cf.outKeyPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(*cf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
|
||||
if !isStdio(*cf.outCertPath) {
|
||||
if _, err := os.Stat(*cf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
|
||||
}
|
||||
}
|
||||
|
||||
var c cert.Certificate
|
||||
@@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*cf.outKeyPath, b, 0600)
|
||||
err = writeOutput(*cf.outKeyPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
@@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
return fmt.Errorf("error while marshalling certificate: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*cf.outCertPath, b, 0600)
|
||||
err = writeOutput(*cf.outCertPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-crt: %s", err)
|
||||
}
|
||||
@@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*cf.outQRPath, b, 0600)
|
||||
err = writeOutput(*cf.outQRPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
@@ -332,6 +356,7 @@ func caSummary() string {
|
||||
func caHelp(out io.Writer) {
|
||||
cf := newCaFlags()
|
||||
out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n"))
|
||||
out.Write([]byte(stdioHelpText))
|
||||
cf.set.SetOutput(out)
|
||||
cf.set.PrintDefaults()
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) {
|
||||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -argon-iterations uint\n"+
|
||||
" \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+
|
||||
" -argon-memory uint\n"+
|
||||
@@ -84,7 +85,7 @@ func Test_ca(t *testing.T) {
|
||||
err: nil,
|
||||
}
|
||||
|
||||
pwPromptOb := "Enter passphrase: "
|
||||
pwPromptEB := "Enter passphrase: "
|
||||
|
||||
// required args
|
||||
assertHelpError(t, ca(
|
||||
@@ -168,8 +169,8 @@ func Test_ca(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, ca(args, ob, eb, testpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, pwPromptEB, eb.String())
|
||||
|
||||
// test encrypted key with passphrase environment variable
|
||||
os.Remove(keyF.Name())
|
||||
@@ -207,8 +208,8 @@ func Test_ca(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.Error(t, ca(args, ob, eb, errpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, pwPromptEB, eb.String())
|
||||
|
||||
// test when user fails to enter a password
|
||||
os.Remove(keyF.Name())
|
||||
@@ -217,8 +218,8 @@ func Test_ca(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up
|
||||
|
||||
// create valid cert/key for overwrite tests
|
||||
os.Remove(keyF.Name())
|
||||
@@ -247,3 +248,67 @@ func Test_ca(t *testing.T) {
|
||||
os.Remove(keyF.Name())
|
||||
|
||||
}
|
||||
|
||||
// Test_ca_stdio exercises the "-" (stdout) form of the ca command's output
// flags: the cert on stdout, the key on stdout, and the rejection of both on
// stdout at once, including when -encrypt is also set.
func Test_ca_stdio(t *testing.T) {
|
||||
nopw := &StubPasswordReader{}
|
||||
|
||||
// CreateTemp is used only to obtain a unique path; the file is removed
// immediately so the path does not exist when ca is asked to write to it.
keyF, err := os.CreateTemp("", "ca.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
defer os.Remove(keyF.Name())
|
||||
|
||||
crtF, err := os.CreateTemp("", "ca.crt")
|
||||
require.NoError(t, err)
|
||||
os.Remove(crtF.Name())
|
||||
defer os.Remove(crtF.Name())
|
||||
|
||||
// out-crt on stdout, out-key on disk
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, c.IsCA())
|
||||
assert.Equal(t, "test-ca", c.Name())
|
||||
|
||||
// out-key on stdout, out-crt on disk
|
||||
// (the key file written by the previous case is removed first)
os.Remove(keyF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
_, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
|
||||
// dual stdout is rejected up front
|
||||
// (the cert file written by the previous case is removed first; nothing may reach stdout)
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t,
|
||||
ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw),
|
||||
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
|
||||
// an output conflict combined with -encrypt must error BEFORE prompting
|
||||
// for a passphrase; tracker counts every ReadPassword call, so its count must stay at zero
|
||||
tracker := &trackingPasswordReader{}
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t,
|
||||
ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker),
|
||||
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Zero(t, tracker.calls, "passphrase prompt should not have been called")
|
||||
}
|
||||
|
||||
// trackingPasswordReader is a test double whose ReadPassword method counts how
// many times it has been invoked.
type trackingPasswordReader struct {
|
||||
// calls is the number of ReadPassword invocations so far.
calls int
|
||||
}
|
||||
|
||||
// ReadPassword increments the call counter and returns an empty passphrase
// with a nil error.
func (pr *trackingPasswordReader) ReadPassword() ([]byte, error) {
|
||||
pr.calls++
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
@@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
|
||||
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if *cf.outKeyPath != "" {
|
||||
return newHelpErrorf("cannot set -out-key with -pkcs11")
|
||||
}
|
||||
if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
|
||||
return err
|
||||
@@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
|
||||
}
|
||||
}
|
||||
|
||||
var claims ioClaims
|
||||
if err := reserveOutputs(&claims,
|
||||
"out-key", *cf.outKeyPath,
|
||||
"out-pub", *cf.outPubPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isP11 {
|
||||
p11Client, err := pkclient.FromUrl(*cf.p11url)
|
||||
if err != nil {
|
||||
@@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return fmt.Errorf("error while getting public key: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
|
||||
err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
}
|
||||
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
|
||||
err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-pub: %s", err)
|
||||
}
|
||||
@@ -102,6 +112,7 @@ func keygenSummary() string {
|
||||
func keygenHelp(out io.Writer) {
|
||||
cf := newKeygenFlags()
|
||||
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
|
||||
_, _ = out.Write([]byte(stdioHelpText))
|
||||
cf.set.SetOutput(out)
|
||||
cf.set.PrintDefaults()
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) {
|
||||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -curve string\n"+
|
||||
" \tECDH Curve (25519, P256) (default \"25519\")\n"+
|
||||
" -out-key string\n"+
|
||||
@@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, lPub, 32)
|
||||
}
|
||||
|
||||
// Test_keygen_stdio exercises the "-" (stdout) form of keygen's output flags:
// the public key on stdout, the private key on stdout, and the rejection of
// both on stdout at once.
func Test_keygen_stdio(t *testing.T) {
|
||||
// CreateTemp is used only to obtain a unique path; the file is removed
// immediately and only the path is handed to keygen.
keyF, err := os.CreateTemp("", "test.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
defer os.Remove(keyF.Name())
|
||||
|
||||
pubF, err := os.CreateTemp("", "test.pub")
|
||||
require.NoError(t, err)
|
||||
os.Remove(pubF.Name())
|
||||
defer os.Remove(pubF.Name())
|
||||
|
||||
// out-pub on stdout, out-key on disk
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb))
|
||||
assert.Empty(t, eb.String())
|
||||
lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Len(t, lPub, 32)
|
||||
|
||||
// out-key on stdout, out-pub on disk
|
||||
// (the key file written by the previous case is removed first)
os.Remove(keyF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb))
|
||||
assert.Empty(t, eb.String())
|
||||
lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Len(t, lKey, 32)
|
||||
|
||||
// both on stdout is a conflict caught up front
|
||||
// (the call must fail with the exact message below and write nothing to stdout)
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb),
|
||||
`-out-key and -out-pub both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
}
|
||||
|
||||
@@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) {
|
||||
}
|
||||
|
||||
password, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
// Terminal echo is off while reading, so the user's Enter key does not
|
||||
// produce a visible newline. Emit one on stderr to match the prompt.
|
||||
fmt.Fprintln(os.Stderr)
|
||||
|
||||
return password, err
|
||||
}
|
||||
|
||||
@@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
rawCert, err := os.ReadFile(*pf.path)
|
||||
var claims ioClaims
|
||||
if err := reserveInputs(&claims, "path", *pf.path); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawCert, err := readInput("path", *pf.path, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read cert; %s", err)
|
||||
}
|
||||
|
||||
// When the QR is going to stdout, suppress the human-readable text/json
|
||||
// output so the binary stream is not contaminated.
|
||||
qrToStdout := isStdio(*pf.outQRPath)
|
||||
|
||||
var c cert.Certificate
|
||||
var qrBytes []byte
|
||||
part := 0
|
||||
@@ -57,11 +69,13 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return fmt.Errorf("error while unmarshaling cert: %s", err)
|
||||
}
|
||||
|
||||
if *pf.json {
|
||||
jsonCerts = append(jsonCerts, c)
|
||||
} else {
|
||||
_, _ = out.Write([]byte(c.String()))
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
if !qrToStdout {
|
||||
if *pf.json {
|
||||
jsonCerts = append(jsonCerts, c)
|
||||
} else {
|
||||
_, _ = out.Write([]byte(c.String()))
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
if *pf.outQRPath != "" {
|
||||
@@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
part++
|
||||
}
|
||||
|
||||
if *pf.json {
|
||||
if *pf.json && !qrToStdout {
|
||||
b, _ := json.Marshal(jsonCerts)
|
||||
_, _ = out.Write(b)
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
@@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*pf.outQRPath, b, 0600)
|
||||
err = writeOutput(*pf.outQRPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
@@ -107,6 +121,7 @@ func printSummary() string {
|
||||
func printHelp(out io.Writer) {
|
||||
pf := newPrintFlags()
|
||||
out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n"))
|
||||
out.Write([]byte(stdioHelpText))
|
||||
pf.set.SetOutput(out)
|
||||
pf.set.PrintDefaults()
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) {
|
||||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -json\n"+
|
||||
" \tOptional: outputs certificates in json format\n"+
|
||||
" -out-qr string\n"+
|
||||
@@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) {
|
||||
ob.String(),
|
||||
)
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// read cert from stdin
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
withStdin(t, bytes.NewReader(p))
|
||||
err = printCert([]string{"-json", "-path", "-"}, ob, eb)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
||||
`,
|
||||
ob.String(),
|
||||
)
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// -out-qr - sends only the PNG to stdout, suppressing the cert dump
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
withStdin(t, bytes.NewReader(p))
|
||||
err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, eb.String())
|
||||
stdout := ob.Bytes()
|
||||
require.NotEmpty(t, stdout)
|
||||
// PNG magic, no PEM/JSON noise prepended
|
||||
assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8])
|
||||
assert.NotContains(t, string(stdout), "NebulaCertificate")
|
||||
assert.NotContains(t, string(stdout), `"details"`)
|
||||
|
||||
// json + out-qr - still suppresses json
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
withStdin(t, bytes.NewReader(p))
|
||||
err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4])
|
||||
assert.NotContains(t, ob.String(), `"details"`)
|
||||
}
|
||||
|
||||
// NewTestCaCert will generate a CA cert
|
||||
|
||||
+42
-20
@@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
|
||||
return newHelpErrorf("cannot set both -in-pub and -out-key")
|
||||
}
|
||||
if isP11 && *sf.outKeyPath != "" {
|
||||
return newHelpErrorf("cannot set -out-key with -pkcs11")
|
||||
}
|
||||
|
||||
var v4Networks []netip.Prefix
|
||||
var v6Networks []netip.Prefix
|
||||
@@ -102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
|
||||
}
|
||||
|
||||
if *sf.outKeyPath == "" {
|
||||
*sf.outKeyPath = *sf.name + ".key"
|
||||
}
|
||||
if *sf.outCertPath == "" {
|
||||
*sf.outCertPath = *sf.name + ".crt"
|
||||
}
|
||||
|
||||
var claims ioClaims
|
||||
if err := reserveInputs(&claims,
|
||||
"ca-key", *sf.caKeyPath,
|
||||
"ca-crt", *sf.caCertPath,
|
||||
"in-pub", *sf.inPubPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := reserveOutputs(&claims,
|
||||
"out-key", *sf.outKeyPath,
|
||||
"out-crt", *sf.outCertPath,
|
||||
"out-qr", *sf.outQRPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var curve cert.Curve
|
||||
var caKey []byte
|
||||
|
||||
if !isP11 {
|
||||
var rawCAKey []byte
|
||||
rawCAKey, err := os.ReadFile(*sf.caKeyPath)
|
||||
|
||||
rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca-key: %s", err)
|
||||
}
|
||||
@@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
if len(passphrase) == 0 {
|
||||
// ask for a passphrase until we get one
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
errOut.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if errors.Is(err, ErrNoTerminal) {
|
||||
@@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
}
|
||||
}
|
||||
|
||||
rawCACert, err := os.ReadFile(*sf.caCertPath)
|
||||
rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca-crt: %s", err)
|
||||
}
|
||||
@@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
|
||||
if *sf.inPubPath != "" {
|
||||
var pubCurve cert.Curve
|
||||
rawPub, err := os.ReadFile(*sf.inPubPath)
|
||||
rawPub, err := readInput("in-pub", *sf.inPubPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading in-pub: %s", err)
|
||||
}
|
||||
@@ -266,16 +291,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
pub, rawPriv = newKeypair(curve)
|
||||
}
|
||||
|
||||
if *sf.outKeyPath == "" {
|
||||
*sf.outKeyPath = *sf.name + ".key"
|
||||
}
|
||||
|
||||
if *sf.outCertPath == "" {
|
||||
*sf.outCertPath = *sf.name + ".crt"
|
||||
}
|
||||
|
||||
if _, err := os.Stat(*sf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
|
||||
if !isStdio(*sf.outCertPath) {
|
||||
if _, err := os.Stat(*sf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
|
||||
}
|
||||
}
|
||||
|
||||
var crts []cert.Certificate
|
||||
@@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
}
|
||||
|
||||
if !isP11 && *sf.inPubPath == "" {
|
||||
if _, err := os.Stat(*sf.outKeyPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
|
||||
if !isStdio(*sf.outKeyPath) {
|
||||
if _, err := os.Stat(*sf.outKeyPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
|
||||
err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
@@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
b = append(b, sb...)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outCertPath, b, 0600)
|
||||
err = writeOutput(*sf.outCertPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-crt: %s", err)
|
||||
}
|
||||
@@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outQRPath, b, 0600)
|
||||
err = writeOutput(*sf.outQRPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
@@ -440,6 +461,7 @@ func signSummary() string {
|
||||
func signHelp(out io.Writer) {
|
||||
sf := newSignFlags()
|
||||
out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n"))
|
||||
out.Write([]byte(stdioHelpText))
|
||||
sf.set.SetOutput(out)
|
||||
sf.set.PrintDefaults()
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) {
|
||||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -ca-crt string\n"+
|
||||
" \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+
|
||||
" -ca-key string\n"+
|
||||
@@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) {
|
||||
// test with the proper password
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: ", eb.String())
|
||||
|
||||
// test with the proper password in the environment
|
||||
os.Remove(crtF.Name())
|
||||
os.Remove(keyF.Name())
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
@@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) {
|
||||
testpw.password = []byte("invalid password")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, testpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: ", eb.String())
|
||||
|
||||
// test with the wrong password in environment
|
||||
ob.Reset()
|
||||
@@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) {
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, nopw))
|
||||
// normally the user hitting enter on the prompt would add newlines between these
|
||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String())
|
||||
|
||||
// test an error condition
|
||||
ob.Reset()
|
||||
@@ -425,6 +429,106 @@ func Test_signCert(t *testing.T) {
|
||||
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, errpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: ", eb.String())
|
||||
}
|
||||
|
||||
func Test_signCert_stdio(t *testing.T) {
|
||||
nopw := &StubPasswordReader{
|
||||
password: []byte(""),
|
||||
err: nil,
|
||||
}
|
||||
|
||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)
|
||||
|
||||
ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
|
||||
rawCACrt, _ := ca.MarshalPEM()
|
||||
|
||||
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caCrtF.Name())
|
||||
caCrtF.Write(rawCACrt)
|
||||
|
||||
caKeyF, err := os.CreateTemp("", "sign-cert.key")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caKeyF.Name())
|
||||
caKeyF.Write(rawCAKey)
|
||||
|
||||
keyF, err := os.CreateTemp("", "sign.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
defer os.Remove(keyF.Name())
|
||||
|
||||
// ca-key on stdin, cert to stdout
|
||||
withStdin(t, bytes.NewReader(rawCAKey))
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "stdin-test", lCrt.Name())
|
||||
assert.True(t, lCrt.CheckSignature(caPub))
|
||||
|
||||
// two flags reading from stdin should error before any read attempt;
|
||||
// otherwise an interactive shell would hang on io.ReadAll
|
||||
stdinIn := bytes.NewReader(rawCAKey)
|
||||
withStdin(t, stdinIn)
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw),
|
||||
`-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
|
||||
assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front")
|
||||
|
||||
// two flags writing to stdout should error before any output is written
|
||||
// AND before stdin is consumed
|
||||
stdinR := bytes.NewReader(rawCAKey)
|
||||
withStdin(t, stdinR)
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw),
|
||||
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
// stdin should be untouched because the conflict was caught up front
|
||||
assert.Equal(t, len(rawCAKey), stdinR.Len())
|
||||
|
||||
// out-key on stdout, cert on disk
|
||||
keyF2, err := os.CreateTemp("", "sign.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF2.Name())
|
||||
defer os.Remove(keyF2.Name())
|
||||
crtF, err := os.CreateTemp("", "sign.crt")
|
||||
require.NoError(t, err)
|
||||
os.Remove(crtF.Name())
|
||||
defer os.Remove(crtF.Name())
|
||||
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
_, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
|
||||
// in-pub on stdin (caller already has a keypair, only the cert is generated)
|
||||
inPub, _ := x25519Keypair()
|
||||
rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)
|
||||
|
||||
withStdin(t, bytes.NewReader(rawInPub))
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "in-pub-test", stdinCrt.Name())
|
||||
assert.Equal(t, inPub, stdinCrt.PublicKey())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
	// stdioPath is the sentinel path value that means "use stdin" for an
	// input flag or "use stdout" for an output flag, rather than a file on
	// disk.
	stdioPath = "-"

	// stdioHelpText is printed directly beneath the Usage line of each
	// subcommand's help so the "-" convention is explained once rather than
	// repeated on every flag.
	stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"
)

// stdinReader is where input is read from when an input flag is set to "-".
// It is a package-level variable so tests can substitute a deterministic
// reader; tests that reassign it cannot run with t.Parallel().
var stdinReader io.Reader = os.Stdin
|
||||
|
||||
// ioClaims remembers, for one command invocation, which flag has taken
// stdin and which flag has taken stdout, so that a second flag asking for the
// same stream can be refused.
type ioClaims struct {
	// in is the name of the flag that claimed stdin and out the name of the
	// flag that claimed stdout; an empty string means unclaimed.
	in, out string
}
|
||||
|
||||
func (c *ioClaims) claimIn(flagName string) error {
|
||||
if c.in != "" && c.in != flagName {
|
||||
return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", c.in, flagName, stdioPath)
|
||||
}
|
||||
c.in = flagName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ioClaims) claimOut(flagName string) error {
|
||||
if c.out != "" && c.out != flagName {
|
||||
return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", c.out, flagName, stdioPath)
|
||||
}
|
||||
c.out = flagName
|
||||
return nil
|
||||
}
|
||||
|
||||
// reserveInputs walks alternating (flagName, path) pairs and claims stdin
|
||||
// for any path equal to stdioPath. It must be called before any input is
|
||||
// read so a conflict can be reported immediately instead of blocking on
|
||||
// io.ReadAll while waiting for input that will never arrive.
|
||||
func reserveInputs(claims *ioClaims, pairs ...string) error {
|
||||
return reserveStdio(claims, "reserveInputs", (*ioClaims).claimIn, pairs)
|
||||
}
|
||||
|
||||
// reserveOutputs walks alternating (flagName, path) pairs and claims stdout
|
||||
// for any path equal to stdioPath. It must be called before any output is
|
||||
// written so a conflict cannot leave one stream half written before the
|
||||
// second flag fails.
|
||||
func reserveOutputs(claims *ioClaims, pairs ...string) error {
|
||||
return reserveStdio(claims, "reserveOutputs", (*ioClaims).claimOut, pairs)
|
||||
}
|
||||
|
||||
func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error {
|
||||
if len(pairs)%2 != 0 {
|
||||
panic(who + " requires alternating name, path pairs")
|
||||
}
|
||||
for i := 0; i < len(pairs); i += 2 {
|
||||
name, path := pairs[i], pairs[i+1]
|
||||
if path != stdioPath {
|
||||
continue
|
||||
}
|
||||
if err := claim(claims, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readInput returns the bytes referenced by path, reading from stdin when
|
||||
// path is stdioPath.
|
||||
func readInput(flagName, path string, claims *ioClaims) ([]byte, error) {
|
||||
if path == stdioPath {
|
||||
if err := claims.claimIn(flagName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return io.ReadAll(stdinReader)
|
||||
}
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
// openInput returns a reader for path. When path is stdioPath the returned
|
||||
// reader wraps stdin and Close is a no-op.
|
||||
func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) {
|
||||
if path == stdioPath {
|
||||
if err := claims.claimIn(flagName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return io.NopCloser(stdinReader), nil
|
||||
}
|
||||
return os.Open(path)
|
||||
}
|
||||
|
||||
// writeOutput writes data to path, or to stdout when path is stdioPath. perm
|
||||
// is only used for file output. The caller must have already claimed stdout
|
||||
// via reserveOutputs before invoking with stdioPath.
|
||||
func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error {
|
||||
if path == stdioPath {
|
||||
_, err := stdout.Write(data)
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, perm)
|
||||
}
|
||||
|
||||
// isStdio reports whether path is the stdio sentinel and so should skip
|
||||
// existence checks like "refuse to overwrite".
|
||||
func isStdio(path string) bool {
|
||||
return path == stdioPath
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// withStdin temporarily replaces stdinReader for the duration of t.
|
||||
func withStdin(t *testing.T, r io.Reader) {
|
||||
t.Helper()
|
||||
prev := stdinReader
|
||||
stdinReader = r
|
||||
t.Cleanup(func() { stdinReader = prev })
|
||||
}
|
||||
|
||||
func Test_readInput_stdin(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hello"))
|
||||
var claims ioClaims
|
||||
|
||||
got, err := readInput("path", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello"), got)
|
||||
assert.Equal(t, "path", claims.in)
|
||||
}
|
||||
|
||||
func Test_readInput_file(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
p := filepath.Join(dir, "f")
|
||||
require.NoError(t, os.WriteFile(p, []byte("file"), 0600))
|
||||
var claims ioClaims
|
||||
|
||||
got, err := readInput("path", p, &claims)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("file"), got)
|
||||
assert.Empty(t, claims.in)
|
||||
}
|
||||
|
||||
func Test_readInput_doubleStdinErrors(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hello"))
|
||||
var claims ioClaims
|
||||
|
||||
_, err := readInput("ca-key", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = readInput("ca-crt", "-", &claims)
|
||||
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
||||
func Test_openInput_stdin(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hi"))
|
||||
var claims ioClaims
|
||||
|
||||
r, err := openInput("ca", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
defer r.Close()
|
||||
b, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("hi"), b)
|
||||
}
|
||||
|
||||
func Test_openInput_doubleStdinErrors(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hi"))
|
||||
var claims ioClaims
|
||||
|
||||
r, err := openInput("ca", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
r.Close()
|
||||
|
||||
_, err = openInput("crt", "-", &claims)
|
||||
require.EqualError(t, err, `-ca and -crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
||||
func Test_writeOutput_stdout(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
|
||||
err := writeOutput("-", []byte("payload"), 0600, out)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "payload", out.String())
|
||||
}
|
||||
|
||||
func Test_writeOutput_file(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
p := filepath.Join(dir, "f")
|
||||
out := &bytes.Buffer{}
|
||||
|
||||
err := writeOutput(p, []byte("payload"), 0600, out)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, out.String())
|
||||
got, err := os.ReadFile(p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("payload"), got)
|
||||
}
|
||||
|
||||
func Test_reserveOutputs_noConflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
require.NoError(t, reserveOutputs(&claims,
|
||||
"out-key", "/tmp/key",
|
||||
"out-crt", "-",
|
||||
"out-qr", "",
|
||||
))
|
||||
assert.Equal(t, "out-crt", claims.out)
|
||||
}
|
||||
|
||||
func Test_reserveOutputs_conflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
err := reserveOutputs(&claims,
|
||||
"out-key", "-",
|
||||
"out-crt", "-",
|
||||
)
|
||||
require.EqualError(t, err, `-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
}
|
||||
|
||||
func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
require.NotNil(t, r)
|
||||
}()
|
||||
var claims ioClaims
|
||||
_ = reserveOutputs(&claims, "out-key")
|
||||
}
|
||||
|
||||
func Test_reserveInputs_noConflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
require.NoError(t, reserveInputs(&claims,
|
||||
"ca-key", "/tmp/ca.key",
|
||||
"ca-crt", "-",
|
||||
"in-pub", "",
|
||||
))
|
||||
assert.Equal(t, "ca-crt", claims.in)
|
||||
}
|
||||
|
||||
func Test_reserveInputs_conflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
err := reserveInputs(&claims,
|
||||
"ca-key", "-",
|
||||
"ca-crt", "-",
|
||||
)
|
||||
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
||||
func Test_claimIn_idempotent(t *testing.T) {
|
||||
// pre-claim then a lazy re-claim of the same flag should be a no-op
|
||||
var claims ioClaims
|
||||
require.NoError(t, claims.claimIn("ca-key"))
|
||||
require.NoError(t, claims.claimIn("ca-key"))
|
||||
assert.Equal(t, "ca-key", claims.in)
|
||||
}
|
||||
|
||||
func Test_claimOut_idempotent(t *testing.T) {
|
||||
var claims ioClaims
|
||||
require.NoError(t, claims.claimOut("out-crt"))
|
||||
require.NoError(t, claims.claimOut("out-crt"))
|
||||
assert.Equal(t, "out-crt", claims.out)
|
||||
}
|
||||
|
||||
func Test_isStdio(t *testing.T) {
|
||||
assert.True(t, isStdio("-"))
|
||||
assert.False(t, isStdio(""))
|
||||
assert.False(t, isStdio("./-"))
|
||||
assert.False(t, isStdio("foo"))
|
||||
}
|
||||
@@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
caFile, err := os.Open(*vf.caPath)
|
||||
var claims ioClaims
|
||||
if err := reserveInputs(&claims,
|
||||
"ca", *vf.caPath,
|
||||
"crt", *vf.certPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
caReader, err := openInput("ca", *vf.caPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca: %w", err)
|
||||
}
|
||||
defer caFile.Close()
|
||||
defer caReader.Close()
|
||||
|
||||
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
|
||||
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
|
||||
if err != nil && !errors.Is(err, cert.ErrExpired) {
|
||||
return fmt.Errorf("error while adding ca cert to pool: %w", err)
|
||||
}
|
||||
|
||||
rawCert, err := os.ReadFile(*vf.certPath)
|
||||
rawCert, err := readInput("crt", *vf.certPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read crt: %w", err)
|
||||
}
|
||||
@@ -85,6 +93,7 @@ func verifySummary() string {
|
||||
func verifyHelp(out io.Writer) {
|
||||
vf := newVerifyFlags()
|
||||
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
|
||||
_, _ = out.Write([]byte(stdioHelpText))
|
||||
vf.set.SetOutput(out)
|
||||
vf.set.PrintDefaults()
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) {
|
||||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -ca string\n"+
|
||||
" \tRequired: path to a file containing one or more ca certificates\n"+
|
||||
" -crt string\n"+
|
||||
@@ -122,3 +123,46 @@ func Test_verify(t *testing.T) {
|
||||
assert.Empty(t, eb.String())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_verify_stdio(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
|
||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
|
||||
caPEM, _ := ca.MarshalPEM()
|
||||
|
||||
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||
crtPEM, _ := crt.MarshalPEM()
|
||||
|
||||
caFile, err := os.CreateTemp("", "verify-ca")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caFile.Name())
|
||||
caFile.Write(caPEM)
|
||||
|
||||
// crt on stdin, ca on disk
|
||||
withStdin(t, bytes.NewReader(crtPEM))
|
||||
require.NoError(t, verify([]string{"-ca", caFile.Name(), "-crt", "-"}, ob, eb))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// ca on stdin, crt on disk
|
||||
certFile, err := os.CreateTemp("", "verify-cert")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(certFile.Name())
|
||||
certFile.Write(crtPEM)
|
||||
|
||||
withStdin(t, bytes.NewReader(caPEM))
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, verify([]string{"-ca", "-", "-crt", certFile.Name()}, ob, eb))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// both flags on stdin should error
|
||||
withStdin(t, bytes.NewReader(caPEM))
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t, verify([]string{"-ca", "-", "-crt", "-"}, ob, eb),
|
||||
`-ca and -crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
||||
@@ -61,9 +61,12 @@ func main() {
|
||||
}
|
||||
|
||||
if *configPath == "" {
|
||||
fmt.Println("-config flag must be set")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
p, err := config.DefaultPath()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
*configPath = p
|
||||
}
|
||||
|
||||
c := config.NewC(l)
|
||||
|
||||
@@ -3,8 +3,6 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/slackhq/nebula"
|
||||
@@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func fileExists(filename string) bool {
|
||||
_, err := os.Stat(filename)
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
|
||||
if *configPath == "" {
|
||||
ex, err := os.Executable()
|
||||
p, err := config.DefaultPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*configPath = filepath.Dir(ex) + "/config.yaml"
|
||||
if !fileExists(*configPath) {
|
||||
*configPath = filepath.Dir(ex) + "/config.yml"
|
||||
}
|
||||
*configPath = p
|
||||
}
|
||||
|
||||
svcConfig := &service.Config{
|
||||
|
||||
+6
-3
@@ -50,9 +50,12 @@ func main() {
|
||||
}
|
||||
|
||||
if *configPath == "" {
|
||||
fmt.Println("-config flag must be set")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
p, err := config.DefaultPath()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
*configPath = p
|
||||
}
|
||||
|
||||
l := logging.NewLogger(os.Stdout)
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// DefaultPath returns a path to a config file alongside the running executable, preferring config.yaml over config.yml.
|
||||
// If neither file exists an error is returned that names both paths checked.
|
||||
func DefaultPath() (string, error) {
|
||||
ex, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return defaultPathInDir(filepath.Dir(ex))
|
||||
}
|
||||
|
||||
// defaultPathInDir returns the path of the config file inside dir, trying
// config.yaml first and config.yml second. The first candidate that os.Stat
// accepts wins.
//
// When neither candidate can be used the error names both paths. If a stat
// failed for a reason other than the file not existing (for example permission
// denied or an invalid path), the first such failure is wrapped into the error
// so the real cause is not hidden behind "no default config found".
func defaultPathInDir(dir string) (string, error) {
	yamlPath := filepath.Join(dir, "config.yaml")
	ymlPath := filepath.Join(dir, "config.yml")

	var statErr error
	for _, candidate := range []string{yamlPath, ymlPath} {
		_, err := os.Stat(candidate)
		if err == nil {
			return candidate, nil
		}
		// Keep only the first failure that is not a plain "does not exist";
		// a missing file is the expected case and needs no extra detail.
		if statErr == nil && !os.IsNotExist(err) {
			statErr = err
		}
	}

	if statErr != nil {
		return "", fmt.Errorf("no default config found at %s or %s: %w", yamlPath, ymlPath, statErr)
	}
	return "", fmt.Errorf("no default config found at %s or %s", yamlPath, ymlPath)
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultPathInDir(t *testing.T) {
|
||||
t.Run("prefers config.yaml when both exist", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
want := filepath.Join(dir, "config.yaml")
|
||||
other := filepath.Join(dir, "config.yml")
|
||||
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
|
||||
require.NoError(t, os.WriteFile(other, []byte("a: 2"), 0644))
|
||||
|
||||
got, err := defaultPathInDir(dir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
|
||||
t.Run("returns config.yaml when only it exists", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
want := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
|
||||
|
||||
got, err := defaultPathInDir(dir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
|
||||
t.Run("falls back to config.yml when only it exists", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
want := filepath.Join(dir, "config.yml")
|
||||
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
|
||||
|
||||
got, err := defaultPathInDir(dir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
|
||||
t.Run("errors when neither exists and names both paths", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
got, err := defaultPathInDir(dir)
|
||||
assert.Empty(t, got)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), filepath.Join(dir, "config.yaml"))
|
||||
assert.Contains(t, err.Error(), filepath.Join(dir, "config.yml"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultPath(t *testing.T) {
|
||||
got, err := DefaultPath()
|
||||
if err != nil {
|
||||
ex, exErr := os.Executable()
|
||||
require.NoError(t, exErr)
|
||||
assert.Contains(t, err.Error(), filepath.Dir(ex))
|
||||
return
|
||||
}
|
||||
ex, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Dir(ex), filepath.Dir(got))
|
||||
assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(got))
|
||||
}
|
||||
+12
-42
@@ -11,7 +11,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@@ -45,19 +44,16 @@ type connectionManager struct {
|
||||
inactivityTimeout atomic.Int64
|
||||
dropInactive atomic.Bool
|
||||
|
||||
metricsTxPunchy metrics.Counter
|
||||
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||
cm := &connectionManager{
|
||||
hostMap: hm,
|
||||
l: l,
|
||||
punchy: p,
|
||||
relayUsed: make(map[uint32]struct{}),
|
||||
relayUsedLock: &sync.RWMutex{},
|
||||
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
||||
hostMap: hm,
|
||||
l: l,
|
||||
punchy: p,
|
||||
relayUsed: make(map[uint32]struct{}),
|
||||
relayUsedLock: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
cm.reload(c, true)
|
||||
@@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
if !outTraffic {
|
||||
// Send a punch packet to keep the NAT state alive
|
||||
cm.sendPunch(hostinfo)
|
||||
cm.punchy.SendPunch(hostinfo)
|
||||
}
|
||||
|
||||
return decision, hostinfo, primary
|
||||
@@ -400,17 +396,16 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test the tunnel.
|
||||
// Just maintain NAT state if configured to do so.
|
||||
cm.sendPunch(hostinfo)
|
||||
cm.punchy.SendPunch(hostinfo)
|
||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
||||
return doNothing, nil, nil
|
||||
}
|
||||
|
||||
if cm.punchy.GetTargetEverything() {
|
||||
// This is similar to the old punchy behavior with a slight optimization.
|
||||
// We aren't receiving traffic but we are sending it, punch on all known
|
||||
// ips in case we need to re-prime NAT state
|
||||
cm.sendPunch(hostinfo)
|
||||
}
|
||||
// We aren't receiving traffic but we are sending it. The outbound
|
||||
// traffic itself refreshes the primary remote's NAT state; this
|
||||
// fans out to non-primary remotes, but only if target_all_remotes
|
||||
// is configured.
|
||||
cm.punchy.SendPunchToAll(hostinfo)
|
||||
|
||||
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||
@@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
if !cm.punchy.GetPunch() {
|
||||
// Punching is disabled
|
||||
return
|
||||
}
|
||||
|
||||
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
||||
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
|
||||
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
|
||||
// would lose the ability to notify us and punchy.respond would become unreliable.
|
||||
return
|
||||
}
|
||||
|
||||
if cm.punchy.GetTargetEverything() {
|
||||
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||
cm.metricsTxPunchy.Inc(1)
|
||||
cm.intf.outside.WriteTo([]byte{1}, addr)
|
||||
})
|
||||
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
cm.metricsTxPunchy.Inc(1)
|
||||
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := cm.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
|
||||
@@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(test.NewLogger())
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
p := []byte("")
|
||||
@@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(test.NewLogger())
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
p := []byte("")
|
||||
@@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
||||
conf.Settings["tunnels"] = map[string]any{
|
||||
"drop_inactive": true,
|
||||
}
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
assert.True(t, nc.dropInactive.Load())
|
||||
nc.intf = ifce
|
||||
@@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(test.NewLogger())
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
ifce.connectionManager = nc
|
||||
|
||||
+5
-4
@@ -7,13 +7,14 @@ import (
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/handshake"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
)
|
||||
|
||||
const ReplayWindow = 1024
|
||||
|
||||
type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
eKey noiseutil.CipherState
|
||||
dKey noiseutil.CipherState
|
||||
myCert cert.Certificate
|
||||
peerCert *cert.CachedCertificate
|
||||
initiator bool
|
||||
@@ -31,8 +32,8 @@ func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
|
||||
myCert: r.MyCert,
|
||||
initiator: r.Initiator,
|
||||
peerCert: r.RemoteCert,
|
||||
eKey: NewNebulaCipherState(r.EKey),
|
||||
dKey: NewNebulaCipherState(r.DKey),
|
||||
eKey: noiseutil.NewCipherState(r.EKey, r.Cipher),
|
||||
dKey: noiseutil.NewCipherState(r.DKey, r.Cipher),
|
||||
window: NewBits(ReplayWindow),
|
||||
}
|
||||
ci.messageCounter.Add(r.MessageIndex)
|
||||
|
||||
+92
-26
@@ -11,19 +11,21 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
type dnsServer struct {
|
||||
sync.RWMutex
|
||||
l *slog.Logger
|
||||
ctx context.Context
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
hostMap *HostMap
|
||||
myVpnAddrsTable *bart.Lite
|
||||
l *slog.Logger
|
||||
ctx context.Context
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
hostMap *HostMap
|
||||
pki *PKI
|
||||
|
||||
// selfHost is the cached FQDN we last seeded for ourselves
|
||||
selfHost string
|
||||
|
||||
mux *dns.ServeMux
|
||||
|
||||
@@ -55,14 +57,14 @@ type dnsServer struct {
|
||||
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
|
||||
// watcher that tears the listener down on nebula shutdown. The returned
|
||||
// pointer is always non-nil, even on error.
|
||||
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||
ds := &dnsServer{
|
||||
l: l,
|
||||
ctx: ctx,
|
||||
dnsMap4: make(map[string]netip.Addr),
|
||||
dnsMap6: make(map[string]netip.Addr),
|
||||
hostMap: hostMap,
|
||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||
l: l,
|
||||
ctx: ctx,
|
||||
dnsMap4: make(map[string]netip.Addr),
|
||||
dnsMap6: make(map[string]netip.Addr),
|
||||
hostMap: hostMap,
|
||||
pki: pki,
|
||||
}
|
||||
ds.mux = dns.NewServeMux()
|
||||
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||
@@ -76,6 +78,7 @@ func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState,
|
||||
if err := ds.reload(c, true); err != nil {
|
||||
return ds, err
|
||||
}
|
||||
ds.seedSelf()
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
@@ -113,7 +116,7 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
|
||||
d.Stop()
|
||||
}
|
||||
// Drop any records that accumulated while enabled; a later re-enable
|
||||
// will repopulate from fresh handshakes.
|
||||
// will repopulate from fresh handshakes and a fresh seedSelf.
|
||||
d.clearRecords()
|
||||
return nil
|
||||
}
|
||||
@@ -121,17 +124,14 @@ func (d *dnsServer) reload(c *config.C, initial bool) error {
|
||||
if running == nil {
|
||||
// Was disabled (or never started); bring it up now.
|
||||
go d.Start()
|
||||
return nil
|
||||
} else if !sameAddr {
|
||||
d.shutdownServer(running, runningStarted, "reload")
|
||||
// Old Start goroutine has now exited; bring up a fresh listener on the new address.
|
||||
go d.Start()
|
||||
}
|
||||
|
||||
if sameAddr {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.shutdownServer(running, runningStarted, "reload")
|
||||
// Old Start goroutine has now exited; bring up a fresh listener on the
|
||||
// new address.
|
||||
go d.Start()
|
||||
// Refresh the self entry every enabled reload so cert renewals that change our name or VPN addresses are picked up.
|
||||
d.seedSelf()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -249,6 +249,20 @@ func (d *dnsServer) QueryCert(data string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves.
|
||||
// Answer self lookups straight from the local cert state.
|
||||
if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) {
|
||||
c := cs.GetDefaultCertificate()
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
b, err := c.MarshalJSON()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
hostinfo := d.hostMap.QueryVpnAddr(ip)
|
||||
if hostinfo == nil {
|
||||
return ""
|
||||
@@ -266,12 +280,60 @@ func (d *dnsServer) QueryCert(data string) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// clearRecords drops all DNS records.
|
||||
// clearRecords drops all DNS records, including the self entry.
|
||||
func (d *dnsServer) clearRecords() {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
clear(d.dnsMap4)
|
||||
clear(d.dnsMap6)
|
||||
d.selfHost = ""
|
||||
}
|
||||
|
||||
// seedSelf inserts (or refreshes) a record for our own cert name pointing at our VPN addresses,
|
||||
// so a single-lighthouse network can resolve the lighthouse's own hostname without the two-process workaround.
|
||||
func (d *dnsServer) seedSelf() {
|
||||
if !d.enabled.Load() {
|
||||
return
|
||||
}
|
||||
cs := d.certState()
|
||||
if cs == nil {
|
||||
return
|
||||
}
|
||||
c := cs.GetDefaultCertificate()
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
newHost := strings.ToLower(c.Name()) + "."
|
||||
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
if d.selfHost != "" && d.selfHost != newHost {
|
||||
delete(d.dnsMap4, d.selfHost)
|
||||
delete(d.dnsMap6, d.selfHost)
|
||||
}
|
||||
d.selfHost = newHost
|
||||
delete(d.dnsMap4, newHost)
|
||||
delete(d.dnsMap6, newHost)
|
||||
haveV4, haveV6 := false, false
|
||||
for _, addr := range cs.myVpnAddrs {
|
||||
if addr.Is4() && !haveV4 {
|
||||
d.dnsMap4[newHost] = addr
|
||||
haveV4 = true
|
||||
} else if addr.Is6() && !haveV6 {
|
||||
d.dnsMap6[newHost] = addr
|
||||
haveV6 = true
|
||||
}
|
||||
if haveV4 && haveV6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsServer) certState() *CertState {
|
||||
if d.pki == nil {
|
||||
return nil
|
||||
}
|
||||
return d.pki.getCertState()
|
||||
}
|
||||
|
||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||
@@ -309,8 +371,12 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
cs := d.certState()
|
||||
if cs == nil || cs.myVpnAddrsTable == nil {
|
||||
return false
|
||||
}
|
||||
//if we found it in this table, it's good
|
||||
return d.myVpnAddrsTable.Contains(b)
|
||||
return cs.myVpnAddrsTable.Contains(b)
|
||||
}
|
||||
|
||||
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
|
||||
@@ -9,7 +9,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -276,6 +279,92 @@ func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// newTestPKI builds a minimal *PKI with a single v1 cert whose name and
|
||||
// VPN addresses are caller-provided, suitable for exercising seedSelf and
|
||||
// QueryCert self handling.
|
||||
func newTestPKI(t *testing.T, name string, addrs []netip.Addr) *PKI {
|
||||
t.Helper()
|
||||
networks := make([]netip.Prefix, 0, len(addrs))
|
||||
for _, a := range addrs {
|
||||
bits := 32
|
||||
if a.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
networks = append(networks, netip.PrefixFrom(a, bits))
|
||||
}
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||
c, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, nil, nil)
|
||||
|
||||
addrsTable := new(bart.Lite)
|
||||
for _, a := range addrs {
|
||||
addrsTable.Insert(netip.PrefixFrom(a, a.BitLen()))
|
||||
}
|
||||
|
||||
cs := &CertState{
|
||||
v2Cert: c,
|
||||
initiatingVersion: cert.Version2,
|
||||
myVpnAddrs: addrs,
|
||||
myVpnAddrsTable: addrsTable,
|
||||
}
|
||||
pki := &PKI{}
|
||||
pki.cs.Store(cs)
|
||||
return pki
|
||||
}
|
||||
|
||||
func TestDnsServer_seedSelf_addsOwnRecord(t *testing.T) {
|
||||
ds, c := newTestDnsServer(t)
|
||||
myV4 := netip.MustParseAddr("10.0.0.1")
|
||||
myV6 := netip.MustParseAddr("fd00::1")
|
||||
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4, myV6})
|
||||
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||
require.NoError(t, ds.reload(c, true))
|
||||
|
||||
ds.seedSelf()
|
||||
got4, exists := ds.Query(dns.TypeA, "lighthouse.")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, myV4, got4)
|
||||
got6, exists := ds.Query(dns.TypeAAAA, "lighthouse.")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, myV6, got6)
|
||||
}
|
||||
|
||||
func TestDnsServer_seedSelf_disabled_noOp(t *testing.T) {
|
||||
ds, c := newTestDnsServer(t)
|
||||
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
|
||||
setDnsConfig(c, "127.0.0.1", "0", true, false)
|
||||
require.NoError(t, ds.reload(c, true))
|
||||
|
||||
ds.seedSelf()
|
||||
_, exists := ds.Query(dns.TypeA, "lighthouse.")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestDnsServer_clearRecords_dropsSelfHost(t *testing.T) {
|
||||
ds, c := newTestDnsServer(t)
|
||||
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
|
||||
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||
require.NoError(t, ds.reload(c, true))
|
||||
ds.seedSelf()
|
||||
require.NotEmpty(t, ds.selfHost)
|
||||
|
||||
ds.clearRecords()
|
||||
assert.Empty(t, ds.selfHost)
|
||||
_, exists := ds.Query(dns.TypeA, "lighthouse.")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestDnsServer_QueryCert_returnsOwnCert(t *testing.T) {
|
||||
ds, _ := newTestDnsServer(t)
|
||||
myV4 := netip.MustParseAddr("10.0.0.1")
|
||||
ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4})
|
||||
|
||||
got := ds.QueryCert(myV4.String() + ".")
|
||||
assert.NotEmpty(t, got, "TXT lookup of our own VPN address should return our cert")
|
||||
|
||||
other := netip.MustParseAddr("10.0.0.99")
|
||||
assert.Empty(t, ds.QueryCert(other.String()+"."), "unknown peer IP should return nothing")
|
||||
}
|
||||
|
||||
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
|
||||
port := freeUDPPort(t)
|
||||
ds, c := newTestDnsServer(t)
|
||||
|
||||
+3
-7
@@ -18,14 +18,10 @@ import (
|
||||
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
|
||||
// before failing the assertion.
|
||||
//
|
||||
// IgnoreCurrent is necessary in the parallelized suite: other tests can
|
||||
// leave goroutines mid-shutdown when this one runs (Stop is async, the
|
||||
// wg.Wait() drain is not blocking on test return). We're checking that
|
||||
// *this* test's setup tears down cleanly, not that the whole suite is
|
||||
// idle at this moment. Intentionally NOT t.Parallel()'d for the same
|
||||
// reason — concurrent test goroutines would always show up.
|
||||
// Intentionally NOT t.Parallel()'d: concurrent tests would have their own
|
||||
// goroutines running and trip the assertion.
|
||||
func TestNoGoroutineLeaks(t *testing.T) {
|
||||
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
//go:build e2e_testing
|
||||
// +build e2e_testing
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestSSHDLifecycle(t *testing.T) {
|
||||
// TestSSHDLifecycle exercises the in-process sshd through several config reloads and a Control.Stop.
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(
|
||||
cert.Version1, cert.Curve_CURVE25519,
|
||||
time.Now(), time.Now().Add(10*time.Minute),
|
||||
nil, nil, []string{},
|
||||
)
|
||||
|
||||
hostKeyPEM := generateSSHHostKey(t)
|
||||
clientSigner, clientAuthKey := generateSSHClientKey(t)
|
||||
sshdAddr := allocLoopbackPort(t)
|
||||
|
||||
overrides := m{
|
||||
"sshd": m{
|
||||
"enabled": true,
|
||||
"listen": sshdAddr,
|
||||
"host_key": hostKeyPEM,
|
||||
"authorized_users": []m{{
|
||||
"user": "tester",
|
||||
"keys": []string{clientAuthKey},
|
||||
}},
|
||||
},
|
||||
}
|
||||
control, _, _, _ := newSimpleServer(cert.Version1, ca, caKey, "sshd-test", "10.222.0.1/24", overrides)
|
||||
control.Start()
|
||||
t.Cleanup(func() { control.Stop() })
|
||||
|
||||
// sshd binds in a goroutine after Start returns; wait for it.
|
||||
require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
|
||||
"sshd never started listening")
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
out := sshExecReload(t, sshdAddr, clientSigner)
|
||||
assert.Contains(t, out, "Reloading config", "reload cycle %d", i)
|
||||
require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
|
||||
"sshd not listening after reload cycle %d", i)
|
||||
}
|
||||
|
||||
control.Stop()
|
||||
require.Eventually(t, func() bool { return !canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond,
|
||||
"sshd still listening after Control.Stop")
|
||||
}
|
||||
|
||||
// canDial reports whether a TCP connection to addr can be opened within
// 100ms. A successful connection is closed again immediately.
func canDial(addr string) bool {
	conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
	if err == nil {
		_ = conn.Close() // probe only; a close failure is irrelevant here
	}
	return err == nil
}
|
||||
|
||||
// allocLoopbackPort grabs an unused TCP port on 127.0.0.1, closes it, and returns the address. There
|
||||
// is a small race between releasing the port and the sshd reclaiming it; in practice the OS keeps the
|
||||
// port available long enough for the test to bind it.
|
||||
func allocLoopbackPort(t *testing.T) string {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := l.Addr().String()
|
||||
require.NoError(t, l.Close())
|
||||
return addr
|
||||
}
|
||||
|
||||
func generateSSHHostKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
block, err := ssh.MarshalPrivateKey(priv, "nebula-e2e-host")
|
||||
require.NoError(t, err)
|
||||
return string(pem.EncodeToMemory(block))
|
||||
}
|
||||
|
||||
func generateSSHClientKey(t *testing.T) (ssh.Signer, string) {
|
||||
t.Helper()
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
signer, err := ssh.NewSignerFromKey(priv)
|
||||
require.NoError(t, err)
|
||||
auth := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
|
||||
return signer, auth
|
||||
}
|
||||
|
||||
func sshExecReload(t *testing.T, addr string, signer ssh.Signer) string {
|
||||
t.Helper()
|
||||
cfg := &ssh.ClientConfig{
|
||||
User: "tester",
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
client, err := ssh.Dial("tcp", addr, cfg)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
sess, err := client.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer sess.Close()
|
||||
|
||||
// reload tears the channel down before sending exit-status, so Output returns an error on the
|
||||
// channel close. The output buffer still has whatever the reload callback wrote before that.
|
||||
out, _ := sess.Output("reload")
|
||||
return string(out)
|
||||
}
|
||||
@@ -138,6 +138,14 @@ listen:
|
||||
# max, net.core.rmem_max and net.core.wmem_max
|
||||
#read_buffer: 10485760
|
||||
#write_buffer: 10485760
|
||||
|
||||
# On Windows only
|
||||
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to UDP at the listener port.
|
||||
# WFP sits below Windows Defender Firewall, so this lets peer handshakes reach Nebula's outside socket regardless
|
||||
# of WDF's inbound rules.
|
||||
# Default true; set to false to leave WDF in charge of inbound decisions on the listener port. Not reloadable.
|
||||
#windows_bypass_wdf: true
|
||||
|
||||
# By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection
|
||||
# in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running
|
||||
# on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes.
|
||||
@@ -163,17 +171,21 @@ listen:
|
||||
|
||||
punchy:
|
||||
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
|
||||
# This setting is reloadable.
|
||||
punch: true
|
||||
|
||||
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails
|
||||
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
|
||||
# Default is false
|
||||
# This setting is reloadable.
|
||||
#respond: true
|
||||
|
||||
# delays a punch response for misbehaving NATs, default is 1 second.
|
||||
# This setting is reloadable.
|
||||
#delay: 1s
|
||||
|
||||
# set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect.
|
||||
# This setting is reloadable.
|
||||
#respond_delay: 5s
|
||||
|
||||
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
|
||||
@@ -282,6 +294,24 @@ tun:
|
||||
# metric: 100
|
||||
# install: true
|
||||
|
||||
# On Windows only, sets the network category of the nebula interface. Without this, Windows often
|
||||
# leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more
|
||||
# restrictive than you usually want for an overlay between trusted peers. Valid values:
|
||||
# private - treat the nebula network as a private/trusted network (default)
|
||||
# public - treat it as a public/untrusted network
|
||||
# domain - treat it as a domain-authenticated network
|
||||
# unset - leave whatever Windows decided alone
|
||||
# Not reloadable.
|
||||
#network_category: private
|
||||
|
||||
# On Windows only
|
||||
# When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to the nebula adapter LUID.
|
||||
# WFP sits below Windows Defender Firewall, so this lets inbound traffic through regardless of WDF rules.
|
||||
# Filters are auto-removed when the adapter goes away.
|
||||
# See listen.windows_bypass_wdf for the matching control over inbound to nebula's outside UDP listener.
|
||||
# Default true; set to false to leave WDF in charge of inbound decisions on the nebula interface. Not reloadable.
|
||||
#windows_bypass_wdf: true
|
||||
|
||||
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
|
||||
# in nebula configuration files. Default false, not reloadable.
|
||||
#use_system_route_table: false
|
||||
|
||||
+43
-26
@@ -58,8 +58,9 @@ type Firewall struct {
|
||||
routableNetworks *bart.Lite
|
||||
|
||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||
assignedNetworks []netip.Prefix
|
||||
hasUnsafeNetworks bool
|
||||
assignedNetworks []netip.Prefix
|
||||
// unsafeNetworks is the list of unsafe networks issued to us in the certificate
|
||||
unsafeNetworks []netip.Prefix
|
||||
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
for _, n := range c.UnsafeNetworks() {
|
||||
unsafeNetworks := c.UnsafeNetworks()
|
||||
for _, n := range unsafeNetworks {
|
||||
routableNetworks.Insert(n)
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
@@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||
l: l,
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
l: l,
|
||||
|
||||
incomingMetrics: firewallMetrics{
|
||||
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
|
||||
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
||||
}
|
||||
|
||||
if localCidr == "" {
|
||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
||||
if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
@@ -1055,7 +1055,6 @@ func (r *rule) sanity() error {
|
||||
}
|
||||
|
||||
func parsePort(s string) (int32, int32, error) {
|
||||
var err error
|
||||
const notAPort int32 = -2
|
||||
if s == "any" {
|
||||
return firewall.PortAny, firewall.PortAny, nil
|
||||
@@ -1064,11 +1063,11 @@ func parsePort(s string) (int32, int32, error) {
|
||||
return firewall.PortFragment, firewall.PortFragment, nil
|
||||
}
|
||||
if !strings.Contains(s, `-`) {
|
||||
rPort, err := strconv.Atoi(s)
|
||||
rPort, err := parsePortValue("", s)
|
||||
if err != nil {
|
||||
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
|
||||
return notAPort, notAPort, err
|
||||
}
|
||||
return int32(rPort), int32(rPort), nil
|
||||
return rPort, rPort, nil
|
||||
}
|
||||
|
||||
sPorts := strings.SplitN(s, `-`, 2)
|
||||
@@ -1079,22 +1078,40 @@ func parsePort(s string) (int32, int32, error) {
|
||||
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
|
||||
}
|
||||
|
||||
rStartPort, err := strconv.Atoi(sPorts[0])
|
||||
startPort, err := parsePortValue("beginning range ", sPorts[0])
|
||||
if err != nil {
|
||||
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
|
||||
return notAPort, notAPort, err
|
||||
}
|
||||
|
||||
rEndPort, err := strconv.Atoi(sPorts[1])
|
||||
endPort, err := parsePortValue("ending range ", sPorts[1])
|
||||
if err != nil {
|
||||
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
|
||||
return notAPort, notAPort, err
|
||||
}
|
||||
|
||||
startPort := int32(rStartPort)
|
||||
endPort := int32(rEndPort)
|
||||
|
||||
if startPort == firewall.PortAny {
|
||||
endPort = firewall.PortAny
|
||||
}
|
||||
|
||||
return startPort, endPort, nil
|
||||
}
|
||||
|
||||
// parsePortValue parses s as a base-10 port number in [0, 65535] and returns
// it as an int32.
//
// Parsing goes through strconv.ParseUint with bitSize 16, which refuses signs,
// non-decimal bytes, and anything above 65535. The value handed to the int32
// conversion therefore always fits, so truncation can never turn a bad input
// into firewall.PortAny (0) or firewall.PortFragment (-1).
//
// prefix is prepended verbatim to both error messages. Callers pass "" for a
// single port and "beginning range " / "ending range " for range bounds,
// which preserves the historical error strings. On error the returned port is
// 0 and must be ignored.
func parsePortValue(prefix, s string) (int32, error) {
	port, err := strconv.ParseUint(s, 10, 16)
	switch {
	case err == nil:
		return int32(port), nil
	case errors.Is(err, strconv.ErrRange):
		// All digits, but the value does not fit in 16 bits.
		return 0, fmt.Errorf("%sout of range [0,65535]; `%s`", prefix, s)
	default:
		return 0, fmt.Errorf("%swas not a number; `%s`", prefix, s)
	}
}
|
||||
|
||||
@@ -1029,6 +1029,75 @@ func Test_parsePort(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test_parsePort_invalid covers inputs that must error. The named bug is
|
||||
// that int32(strconv.Atoi("4294967296")) truncates to 0 == firewall.PortAny,
|
||||
// silently turning a typo into a match-all-ports rule; the rest are
|
||||
// representative syntax/range probes.
|
||||
func Test_parsePort_invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErrContains string
|
||||
}{
|
||||
// Numeric overflow (the named bug + boundary).
|
||||
{"named bug: 2^32 truncates to PortAny", "4294967296", "out of range"},
|
||||
{"just above max real port", "65536", "out of range"},
|
||||
|
||||
// Negatives route through the range branch and hit the empty-half
|
||||
// guard; included as defense in depth so a future refactor cannot
|
||||
// accidentally reach the int32 cast.
|
||||
{"negative", "-1", "could not be parsed"},
|
||||
|
||||
// Syntax probes.
|
||||
{"NUL between digits", "4\x002", "was not a number"},
|
||||
{"hex notation", "0x10", "was not a number"},
|
||||
{"scientific notation", "1e3", "was not a number"},
|
||||
{"leading whitespace", " 42", "was not a number"},
|
||||
{"fullwidth digits", "42", "was not a number"},
|
||||
|
||||
// Range branch.
|
||||
{"range upper out of range", "1-65536", "ending range out of range"},
|
||||
{"range lower out of range", "65536-65537", "beginning range out of range"},
|
||||
{"range with negative upper", "1--1", "ending range was not a number"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, _, err := parsePort(tc.input)
|
||||
require.Error(t, err, "input %q must error", tc.input)
|
||||
require.ErrorContains(t, err, tc.wantErrContains)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test_parsePort_valid_boundaries locks in success cases at 0, 1, and 65535
|
||||
// so a future refactor cannot regress the boundaries.
|
||||
func Test_parsePort_valid_boundaries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantStart int32
|
||||
wantEnd int32
|
||||
}{
|
||||
{"zero is PortAny", "0", 0, 0},
|
||||
{"min real port", "1", 1, 1},
|
||||
{"max real port", "65535", 65535, 65535},
|
||||
{"range zero to max forces end to zero", "0-65535", 0, 0},
|
||||
{"range max to max", "65535-65535", 65535, 65535},
|
||||
{"range one to max", "1-65535", 1, 65535},
|
||||
{"range with whitespace inside", " 1 - 2 ", 1, 2},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s, e, err := parsePort(tc.input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.wantStart, s, "start port")
|
||||
assert.Equal(t, tc.wantEnd, e, "end port")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFirewallFromConfig(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
// Test a bad rule definition
|
||||
|
||||
@@ -9,7 +9,7 @@ require (
|
||||
github.com/armon/go-radix v1.0.0
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/gaissmai/bart v0.26.0
|
||||
github.com/gaissmai/bart v0.27.1
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/kardianos/service v1.2.4
|
||||
@@ -24,12 +24,12 @@ require (
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
go.uber.org/goleak v1.3.0
|
||||
go.yaml.in/yaml/v3 v3.0.4
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/crypto v0.51.0
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/net v0.52.0
|
||||
golang.org/x/net v0.54.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.43.0
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/sys v0.44.0
|
||||
golang.org/x/term v0.43.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||
golang.zx2c4.com/wireguard/windows v0.6.1
|
||||
|
||||
@@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
|
||||
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||
github.com/gaissmai/bart v0.27.1 h1:FysPzqETMJa8q9rNkLW5peT1hq25nLOz8ksHbSVoiAk=
|
||||
github.com/gaissmai/bart v0.27.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
|
||||
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
||||
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
|
||||
@@ -33,6 +33,7 @@ type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
|
||||
type Result struct {
|
||||
EKey *noise.CipherState
|
||||
DKey *noise.CipherState
|
||||
Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in
|
||||
MyCert cert.Certificate
|
||||
RemoteCert *cert.CachedCertificate
|
||||
RemoteIndex uint32
|
||||
@@ -114,6 +115,7 @@ func NewMachine(
|
||||
myVersion: version,
|
||||
result: &Result{
|
||||
Initiator: initiator,
|
||||
Cipher: cred.cipherSuite,
|
||||
},
|
||||
|
||||
multiport: multiport,
|
||||
|
||||
@@ -87,6 +87,7 @@ type HandshakeHostInfo struct {
|
||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
lastRelays []netip.Addr // Relays we attempted to use during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
@@ -221,7 +222,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
fields := []any{
|
||||
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
|
||||
"initiatorIndex", hh.hostinfo.localIndexId,
|
||||
"remoteIndex", hh.hostinfo.remoteIndexId,
|
||||
"durationNs", time.Since(hh.startTime).Nanoseconds(),
|
||||
}
|
||||
// hh.machine can be nil here if buildStage0Packet never succeeded
|
||||
@@ -352,7 +352,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
)
|
||||
}
|
||||
|
||||
hm.f.relayManager.StartRelays(hm.f, vpnIp, hostinfo, stage0)
|
||||
hm.f.relayManager.StartRelays(hm.f, vpnIp, hh, stage0)
|
||||
|
||||
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
|
||||
if !lighthouseTriggered {
|
||||
@@ -494,7 +494,6 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||
"remoteIndex", hostinfo.remoteIndexId,
|
||||
"collision", existingRemoteIndex.vpnAddrs,
|
||||
)
|
||||
}
|
||||
@@ -517,7 +516,6 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||
"remoteIndex", hostinfo.remoteIndexId,
|
||||
"collision", existingRemoteIndex.vpnAddrs,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -409,7 +409,6 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
"error", err,
|
||||
"udpAddr", remote,
|
||||
"counter", c,
|
||||
"attemptedCounter", c,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
+44
-25
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@@ -389,13 +391,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
|
||||
}
|
||||
|
||||
func (f *Interface) reloadFirewall(c *config.C) {
|
||||
//TODO: need to trigger/detect if the certificate changed too
|
||||
if c.HasChanged("firewall") == false {
|
||||
cs := f.pki.getCertState()
|
||||
curCert := cs.getCertificate(cert.Version2)
|
||||
if curCert == nil {
|
||||
curCert = cs.getCertificate(cert.Version1)
|
||||
}
|
||||
|
||||
// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
|
||||
// Check to see if that set has changed, and if so, rebuild the firewall.
|
||||
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)
|
||||
|
||||
if !c.HasChanged("firewall") && !certUnsafeChanged {
|
||||
f.l.Debug("No firewall config change detected")
|
||||
return
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||
fw, err := NewFirewallFromConfig(f.l, cs, c)
|
||||
if err != nil {
|
||||
f.l.Error("Error while creating firewall during reload", "error", err)
|
||||
return
|
||||
@@ -507,33 +518,41 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
|
||||
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
||||
|
||||
emit := func() {
|
||||
f.firewall.EmitStats()
|
||||
f.handshakeManager.EmitStats()
|
||||
udpStats()
|
||||
|
||||
certState := f.pki.getCertState()
|
||||
defaultCrt := certState.GetDefaultCertificate()
|
||||
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
||||
certInitiatingVersion.Update(int64(defaultCrt.Version()))
|
||||
|
||||
if f.udpRaw != nil {
|
||||
if rawStats == nil {
|
||||
rawStats = udp.NewRawStatsEmitter(f.udpRaw)
|
||||
}
|
||||
rawStats()
|
||||
}
|
||||
|
||||
// Report the max certificate version we are capable of using
|
||||
if certState.v2Cert != nil {
|
||||
certMaxVersion.Update(int64(certState.v2Cert.Version()))
|
||||
} else {
|
||||
certMaxVersion.Update(int64(certState.v1Cert.Version()))
|
||||
}
|
||||
}
|
||||
|
||||
// Prime gauges so a Prometheus scrape that lands before the first tick
|
||||
// sees real values instead of the zero defaults (issue #907).
|
||||
emit()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
f.firewall.EmitStats()
|
||||
f.handshakeManager.EmitStats()
|
||||
udpStats()
|
||||
|
||||
certState := f.pki.getCertState()
|
||||
defaultCrt := certState.GetDefaultCertificate()
|
||||
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
||||
certInitiatingVersion.Update(int64(defaultCrt.Version()))
|
||||
|
||||
if f.udpRaw != nil {
|
||||
if rawStats == nil {
|
||||
rawStats = udp.NewRawStatsEmitter(f.udpRaw)
|
||||
}
|
||||
rawStats()
|
||||
}
|
||||
|
||||
// Report the max certificate version we are capable of using
|
||||
if certState.v2Cert != nil {
|
||||
certMaxVersion.Update(int64(certState.v2Cert.Version()))
|
||||
} else {
|
||||
certMaxVersion.Update(int64(certState.v1Cert.Version()))
|
||||
}
|
||||
emit()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/overlay/overlaytest"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that
|
||||
// landed before the first ticker fire used to read 0 for the cert gauges.
|
||||
// emitStats now primes the gauges before entering the ticker loop. We assert
|
||||
// the gauge is zero before the first call and non-zero after.
|
||||
func Test_emitStats_primesGauges(t *testing.T) {
|
||||
defer metrics.DefaultRegistry.UnregisterAll()
|
||||
|
||||
l := test.NewLogger()
|
||||
hostMap := newHostMap(l)
|
||||
preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
notAfter := time.Now().Add(time.Hour)
|
||||
cs := &CertState{
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter},
|
||||
v1Credential: nil,
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &overlaytest.NoopTun{},
|
||||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
|
||||
lightHouse: lh,
|
||||
pki: &PKI{},
|
||||
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
// On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to
|
||||
// *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn,
|
||||
// returns an error, and the emitter falls through to a no-op.
|
||||
writers: []udp.Conn{&udp.StdConn{}},
|
||||
}
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
||||
require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs")
|
||||
|
||||
// Pre-cancel the context so emitStats returns after priming the gauges
|
||||
// without ever reading from ticker.C. The one hour interval is just a
|
||||
// belt-and-suspenders, the test does not expect the ticker to fire.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
ifce.emitStats(ctx, time.Hour)
|
||||
|
||||
ttl := ttlGauge.Value()
|
||||
assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick")
|
||||
assert.LessOrEqual(t, ttl, int64(3600))
|
||||
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value())
|
||||
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value())
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall
|
||||
// rebuilds the firewall when only the certificate's UnsafeNetworks have changed,
|
||||
// even if the firewall section of the YAML has not.
|
||||
func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
|
||||
initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
|
||||
|
||||
// dummyCert avoids dragging the real signing pipeline into a unit test.
|
||||
c1 := &dummyCert{
|
||||
version: cert.Version2,
|
||||
networks: []netip.Prefix{vpnNet},
|
||||
unsafeNetworks: initialUnsafe,
|
||||
}
|
||||
pki := &PKI{}
|
||||
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
|
||||
|
||||
rawYAML := `firewall:
|
||||
outbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
inbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
`
|
||||
cfg := config.NewC(l)
|
||||
require.NoError(t, cfg.LoadString(rawYAML))
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialUnsafe, fw.unsafeNetworks)
|
||||
|
||||
f := &Interface{
|
||||
pki: pki,
|
||||
firewall: fw,
|
||||
l: l,
|
||||
}
|
||||
|
||||
// Swap the cert with a different UnsafeNetworks set.
|
||||
newUnsafe := []netip.Prefix{
|
||||
netip.MustParsePrefix("198.51.100.0/24"),
|
||||
netip.MustParsePrefix("203.0.113.0/24"),
|
||||
}
|
||||
c2 := &dummyCert{
|
||||
version: cert.Version2,
|
||||
networks: []netip.Prefix{vpnNet},
|
||||
unsafeNetworks: newUnsafe,
|
||||
}
|
||||
pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2})
|
||||
|
||||
// Reload with the same YAML so HasChanged("firewall") reports false.
|
||||
require.NoError(t, cfg.ReloadConfigString(rawYAML))
|
||||
require.False(t, cfg.HasChanged("firewall"))
|
||||
|
||||
f.reloadFirewall(cfg)
|
||||
|
||||
assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced")
|
||||
assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks)
|
||||
assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5")))
|
||||
}
|
||||
|
||||
// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when
|
||||
// neither the firewall config nor the cert's UnsafeNetworks have changed.
|
||||
func TestReloadFirewall_NoChange(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
|
||||
unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
|
||||
|
||||
c1 := &dummyCert{
|
||||
version: cert.Version2,
|
||||
networks: []netip.Prefix{vpnNet},
|
||||
unsafeNetworks: unsafe,
|
||||
}
|
||||
pki := &PKI{}
|
||||
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
|
||||
|
||||
rawYAML := `firewall:
|
||||
outbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
inbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
`
|
||||
cfg := config.NewC(l)
|
||||
require.NoError(t, cfg.LoadString(rawYAML))
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &Interface{
|
||||
pki: pki,
|
||||
firewall: fw,
|
||||
l: l,
|
||||
}
|
||||
|
||||
require.NoError(t, cfg.ReloadConfigString(rawYAML))
|
||||
f.reloadFirewall(cfg)
|
||||
|
||||
assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
|
||||
}
|
||||
+27
-50
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@@ -35,7 +34,6 @@ type LightHouse struct {
|
||||
|
||||
myVpnNetworks []netip.Prefix
|
||||
myVpnNetworksTable *bart.Lite
|
||||
punchConn udp.Conn
|
||||
punchy *Punchy
|
||||
|
||||
// Local cache of answers from light houses
|
||||
@@ -75,9 +73,8 @@ type LightHouse struct {
|
||||
|
||||
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
|
||||
|
||||
metrics *MessageMetrics
|
||||
metricHolepunchTx metrics.Counter
|
||||
l *slog.Logger
|
||||
metrics *MessageMetrics
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
||||
@@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
|
||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||
addrMap: make(map[netip.Addr]*RemoteList),
|
||||
nebulaPort: nebulaPort,
|
||||
punchConn: pc,
|
||||
punchy: p,
|
||||
updateTrigger: make(chan struct{}, 1),
|
||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||
@@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
|
||||
|
||||
if c.GetBool("stats.lighthouse_metrics", false) {
|
||||
h.metrics = newLighthouseMetrics()
|
||||
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
|
||||
} else {
|
||||
h.metricHolepunchTx = metrics.NilCounter{}
|
||||
}
|
||||
|
||||
err := h.reload(c, true)
|
||||
@@ -279,16 +272,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
|
||||
if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
|
||||
// Clean up. Entries still in the static_host_map will be re-built.
|
||||
// Entries no longer present must have their (possible) background DNS goroutines stopped.
|
||||
if existingStaticList := lh.staticList.Load(); existingStaticList != nil {
|
||||
ourselves := lh.myVpnNetworks[0].Addr()
|
||||
oldStaticList := lh.staticList.Load()
|
||||
if oldStaticList != nil {
|
||||
lh.RLock()
|
||||
for staticVpnAddr := range *existingStaticList {
|
||||
for staticVpnAddr := range *oldStaticList {
|
||||
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
|
||||
am.hr.Cancel()
|
||||
am.ResetForOwner(ourselves)
|
||||
}
|
||||
}
|
||||
lh.RUnlock()
|
||||
}
|
||||
|
||||
// Build a new list based on current config.
|
||||
staticList := make(map[netip.Addr]struct{})
|
||||
err := lh.loadStaticMap(c, staticList)
|
||||
@@ -296,6 +291,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// For entries removed from static_host_map, stop the DNS goroutine and drop the cached addrs.
|
||||
// All addrs must come from the lighthouses now that it's no longer a static host.
|
||||
if oldStaticList != nil {
|
||||
lh.RLock()
|
||||
for staticVpnAddr := range *oldStaticList {
|
||||
if _, stillStatic := staticList[staticVpnAddr]; stillStatic {
|
||||
continue
|
||||
}
|
||||
if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil {
|
||||
am.ClearHostnameResults()
|
||||
}
|
||||
}
|
||||
lh.RUnlock()
|
||||
}
|
||||
|
||||
lh.staticList.Store(&staticList)
|
||||
if !initial {
|
||||
if c.HasChanged("static_host_map") {
|
||||
@@ -1406,58 +1416,25 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
return
|
||||
}
|
||||
|
||||
empty := []byte{0}
|
||||
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
|
||||
if !vpnPeer.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetDelay())
|
||||
lhh.lh.metricHolepunchTx.Inc(1)
|
||||
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
||||
}()
|
||||
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("Punching",
|
||||
"vpnPeer", vpnPeer,
|
||||
"logVpnAddr", logVpnAddr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
||||
for _, a := range n.Details.V4AddrPorts {
|
||||
b := protoV4AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V6AddrPorts {
|
||||
b := protoV6AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||
// of a double nat or other difficult scenario, this may help establish
|
||||
// a tunnel.
|
||||
if lhh.lh.punchy.GetRespond() {
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("Sending a nebula test packet",
|
||||
"vpnAddr", detailsVpnAddr,
|
||||
)
|
||||
}
|
||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||
// managed by a channel.
|
||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}()
|
||||
}
|
||||
// a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled.
|
||||
lhh.lh.punchy.ScheduleRespond(detailsVpnAddr)
|
||||
}
|
||||
|
||||
func protoAddrToNetAddr(addr *Addr) netip.Addr {
|
||||
|
||||
@@ -303,6 +303,132 @@ func TestLighthouse_reload(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestLighthouse_reloadStaticHostMap verifies that reloading static_host_map applies the new
|
||||
// config rather than appending to it. See issue #718.
|
||||
func TestLighthouse_reloadStaticHostMap(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
c.Settings["static_host_map"] = map[string]any{
|
||||
"10.128.0.2": []any{"1.1.1.1:4242"},
|
||||
}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
|
||||
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
staticHost := netip.MustParseAddr("10.128.0.2")
|
||||
otherHost := netip.MustParseAddr("10.128.0.3")
|
||||
|
||||
// Capture the RemoteList pointer up front; an in-flight handshake would hold the same one
|
||||
// on hostinfo.remotes, so it must reflect every reload below.
|
||||
pinned := lh.Query(staticHost)
|
||||
require.NotNil(t, pinned)
|
||||
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, pinned.CopyAddrs([]netip.Prefix{}))
|
||||
|
||||
// Replace the remote address. The new address should be the only entry.
|
||||
nc := map[string]any{
|
||||
"static_host_map": map[string]any{
|
||||
"10.128.0.2": []any{"2.2.2.2:4242"},
|
||||
},
|
||||
}
|
||||
rc, err := yaml.Marshal(nc)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.ReloadConfigString(string(rc)))
|
||||
|
||||
rl := lh.Query(staticHost)
|
||||
require.NotNil(t, rl)
|
||||
assert.Same(t, pinned, rl, "RemoteList pointer must stay stable so in-flight handshakes pick up the change")
|
||||
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{}))
|
||||
|
||||
// Reload back to the original IP. Mirrors the round-trip in issue #718 step 6-8 where
|
||||
// the buggy reload produced [1.1.1.1, 2.2.2.2, 1.1.1.1] instead of [1.1.1.1].
|
||||
nc = map[string]any{
|
||||
"static_host_map": map[string]any{
|
||||
"10.128.0.2": []any{"1.1.1.1:4242"},
|
||||
},
|
||||
}
|
||||
rc, err = yaml.Marshal(nc)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.ReloadConfigString(string(rc)))
|
||||
|
||||
rl = lh.Query(staticHost)
|
||||
require.NotNil(t, rl)
|
||||
assert.Same(t, pinned, rl)
|
||||
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{}))
|
||||
|
||||
// Reload with the same config. An unchanged entry must not duplicate.
|
||||
require.NoError(t, c.ReloadConfigString(string(rc)))
|
||||
|
||||
rl = lh.Query(staticHost)
|
||||
require.NotNil(t, rl)
|
||||
assert.Same(t, pinned, rl)
|
||||
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{}))
|
||||
|
||||
// Switch back to 2.2.2.2 so the rest of the test continues against a known address.
|
||||
nc = map[string]any{
|
||||
"static_host_map": map[string]any{
|
||||
"10.128.0.2": []any{"2.2.2.2:4242"},
|
||||
},
|
||||
}
|
||||
rc, err = yaml.Marshal(nc)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.ReloadConfigString(string(rc)))
|
||||
|
||||
// Add a second host alongside the first. Both should be present, neither duplicated.
|
||||
nc = map[string]any{
|
||||
"static_host_map": map[string]any{
|
||||
"10.128.0.2": []any{"2.2.2.2:4242"},
|
||||
"10.128.0.3": []any{"3.3.3.3:4242"},
|
||||
},
|
||||
}
|
||||
rc, err = yaml.Marshal(nc)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.ReloadConfigString(string(rc)))
|
||||
|
||||
rl = lh.Query(staticHost)
|
||||
require.NotNil(t, rl)
|
||||
assert.Same(t, pinned, rl, "adding a sibling entry must not displace the existing RemoteList")
|
||||
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{}))
|
||||
|
||||
rl = lh.Query(otherHost)
|
||||
require.NotNil(t, rl)
|
||||
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{}))
|
||||
|
||||
// Drop the first host entirely. The vpnAddr is no longer marked static, our owner
|
||||
// contribution is cleared, but the addrMap entry stays in place so non-static cache
|
||||
// data (from lighthouse queries) on the same RemoteList isn't lost. In-flight handshakes
|
||||
// that already had the pointer see an empty address list rather than retrying stale ones.
|
||||
nc = map[string]any{
|
||||
"static_host_map": map[string]any{
|
||||
"10.128.0.3": []any{"3.3.3.3:4242"},
|
||||
},
|
||||
}
|
||||
rc, err = yaml.Marshal(nc)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.ReloadConfigString(string(rc)))
|
||||
|
||||
_, isStatic := lh.GetStaticHostList()[staticHost]
|
||||
assert.False(t, isStatic)
|
||||
|
||||
rl = lh.Query(staticHost)
|
||||
require.NotNil(t, rl)
|
||||
assert.Same(t, pinned, rl)
|
||||
assert.Empty(t, rl.CopyAddrs([]netip.Prefix{}))
|
||||
|
||||
rl = lh.Query(otherHost)
|
||||
require.NotNil(t, rl)
|
||||
assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{}))
|
||||
}
|
||||
|
||||
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
|
||||
@@ -55,7 +55,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
|
||||
}
|
||||
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
|
||||
ssh, err := sshd.NewSSHServer(ctx, l.With("subsystem", "sshd"))
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
|
||||
}
|
||||
|
||||
hostMap := NewHostMapFromConfig(l, c)
|
||||
punchy := NewPunchyFromConfig(l, c)
|
||||
punchy := NewPunchyFromConfig(l, c, udpConns[0])
|
||||
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
|
||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
||||
if err != nil {
|
||||
@@ -194,7 +194,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
|
||||
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||
|
||||
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
|
||||
ds, err := newDnsServerFromConfig(ctx, l, pki, hostMap, c)
|
||||
if err != nil {
|
||||
l.Warn("Failed to start DNS responder", "error", err)
|
||||
}
|
||||
@@ -273,6 +273,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
|
||||
|
||||
handshakeManager.f = ifce
|
||||
go handshakeManager.Run(ctx)
|
||||
|
||||
punchy.Start(ctx, ifce, hostMap, lightHouse)
|
||||
}
|
||||
|
||||
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
)
|
||||
|
||||
type (
	// endianness is the one byte-order operation the cipher state needs:
	// writing a 64-bit counter into a nonce buffer.
	endianness interface {
		PutUint64(b []byte, v uint64)
	}

	// NebulaCipherState wraps the AEAD used to seal and open data-plane
	// packets.
	NebulaCipherState struct {
		c cipher.AEAD
	}
)

// noiseEndianness is the byte order used when encoding the message counter
// into the nonce.
var noiseEndianness endianness = binary.BigEndian
|
||||
|
||||
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
|
||||
x := s.Cipher()
|
||||
return &NebulaCipherState{c: x.(cipher.AEAD)}
|
||||
}
|
||||
|
||||
// EncryptDanger encrypts and authenticates a given payload.
|
||||
//
|
||||
// out is a destination slice to hold the output of the EncryptDanger operation.
|
||||
// - ad is additional data, which will be authenticated and appended to out, but not encrypted.
|
||||
// - plaintext is encrypted, authenticated and appended to out.
|
||||
// - n is a nonce value which must never be re-used with this key.
|
||||
// - nb is a buffer used for temporary storage in the implementation of this call, which should
|
||||
// be re-used by callers to minimize garbage collection.
|
||||
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
if s != nil {
|
||||
// TODO: Is this okay now that we have made messageCounter atomic?
|
||||
// Alternative may be to split the counter space into ranges
|
||||
//if n <= s.n {
|
||||
// return nil, errors.New("CRITICAL: a duplicate counter value was used")
|
||||
//}
|
||||
//s.n = n
|
||||
nb[0] = 0
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
noiseEndianness.PutUint64(nb[4:], n)
|
||||
out = s.c.Seal(out, nb, plaintext, ad)
|
||||
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
|
||||
return out, nil
|
||||
} else {
|
||||
return nil, errors.New("no cipher state available to encrypt")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
if s != nil {
|
||||
nb[0] = 0
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
noiseEndianness.PutUint64(nb[4:], n)
|
||||
return s.c.Open(out, nb, ciphertext, ad)
|
||||
} else {
|
||||
return []byte{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NebulaCipherState) Overhead() int {
|
||||
if s != nil {
|
||||
return s.c.Overhead()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package noiseutil
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
)
|
||||
|
||||
// CipherStateAESGCM wraps the AES-GCM AEAD for use on the data plane.
// Per the Noise spec, AES-GCM nonces carry the counter in big-endian order.
type CipherStateAESGCM struct {
	// c is the AEAD taken from the post-handshake noise.CipherState.
	c cipher.AEAD
}
|
||||
|
||||
// NewCipherStateAESGCM extracts the underlying AEAD from the post-handshake noise.CipherState.
|
||||
// The caller is responsible for ensuring the noise cipher is actually AES-GCM,
|
||||
// otherwise the type assertion still succeeds but the nonce endianness will be wrong on the wire.
|
||||
func NewCipherStateAESGCM(s *noise.CipherState) *CipherStateAESGCM {
|
||||
return &CipherStateAESGCM{c: s.Cipher().(cipher.AEAD)}
|
||||
}
|
||||
|
||||
func (s *CipherStateAESGCM) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("no cipher state available to encrypt")
|
||||
}
|
||||
nb[0] = 0
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
binary.BigEndian.PutUint64(nb[4:], n)
|
||||
return s.c.Seal(out, nb, plaintext, ad), nil
|
||||
}
|
||||
|
||||
func (s *CipherStateAESGCM) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
if s == nil {
|
||||
return []byte{}, nil
|
||||
}
|
||||
nb[0] = 0
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
binary.BigEndian.PutUint64(nb[4:], n)
|
||||
return s.c.Open(out, nb, ciphertext, ad)
|
||||
}
|
||||
|
||||
func (s *CipherStateAESGCM) Overhead() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.c.Overhead()
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package noiseutil
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
)
|
||||
|
||||
// CipherStateChaChaPoly wraps the ChaCha20-Poly1305 AEAD for use on the data plane.
// The packet counter is written into the nonce little-endian, which is the
// encoding the Noise spec uses for ChaCha20-Poly1305.
type CipherStateChaChaPoly struct {
	// c is the AEAD that seals and opens packets.
	c cipher.AEAD
}
|
||||
|
||||
// NewCipherStateChaChaPoly extracts the underlying AEAD from the post-handshake noise.CipherState.
|
||||
// The caller is responsible for ensuring the noise cipher is actually ChaCha20-Poly1305.
|
||||
func NewCipherStateChaChaPoly(s *noise.CipherState) *CipherStateChaChaPoly {
|
||||
return &CipherStateChaChaPoly{c: s.Cipher().(cipher.AEAD)}
|
||||
}
|
||||
|
||||
func (s *CipherStateChaChaPoly) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("no cipher state available to encrypt")
|
||||
}
|
||||
nb[0] = 0
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
binary.LittleEndian.PutUint64(nb[4:], n)
|
||||
return s.c.Seal(out, nb, plaintext, ad), nil
|
||||
}
|
||||
|
||||
func (s *CipherStateChaChaPoly) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
if s == nil {
|
||||
return []byte{}, nil
|
||||
}
|
||||
nb[0] = 0
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
binary.LittleEndian.PutUint64(nb[4:], n)
|
||||
return s.c.Open(out, nb, ciphertext, ad)
|
||||
}
|
||||
|
||||
func (s *CipherStateChaChaPoly) Overhead() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.c.Overhead()
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package noiseutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
)
|
||||
|
||||
// CipherState is the post-handshake AEAD cipher used for the data plane.
// Each supported cipher has its own concrete implementation in this package with the nonce endianness hardcoded,
// so the encrypt/decrypt fast path avoids interface dispatch on the byte order.
type CipherState interface {
	// EncryptDanger encrypts and authenticates a given payload.
	//
	// - out is the destination slice; the sealed output is appended to it and the extended slice is returned.
	// - ad is additional data, which is authenticated but neither encrypted nor written to out.
	// - plaintext is encrypted, authenticated and appended to out.
	// - n is a nonce value which must never be re-used with this key.
	// - nb is a scratch buffer used to assemble the nonce.
	EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error)

	// DecryptDanger authenticates and decrypts a given payload, with the same argument shape as EncryptDanger.
	DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error)

	// Overhead returns the AEAD tag size, or 0 if the receiver is nil.
	Overhead() int
}
|
||||
|
||||
// NewCipherState wraps the post-handshake noise.CipherState in the per-cipher type that matches cipherFunc.
|
||||
// cipherFunc must be the same cipher used to build the noise CipherSuite that produced s.
|
||||
func NewCipherState(s *noise.CipherState, cipherFunc noise.CipherFunc) CipherState {
|
||||
switch cipherFunc.CipherName() {
|
||||
case CipherAESGCM.CipherName():
|
||||
return NewCipherStateAESGCM(s)
|
||||
case noise.CipherChaChaPoly.CipherName():
|
||||
return NewCipherStateChaChaPoly(s)
|
||||
default:
|
||||
panic(fmt.Sprintf("noiseutil: unsupported cipher %q", cipherFunc.CipherName()))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package noiseutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCipherStateAESGCMRoundtrip(t *testing.T) {
|
||||
enc, dec := buildCipherStates(t, CipherAESGCM)
|
||||
roundtrip(t, NewCipherStateAESGCM(enc), NewCipherStateAESGCM(dec))
|
||||
}
|
||||
|
||||
func TestCipherStateChaChaPolyRoundtrip(t *testing.T) {
|
||||
enc, dec := buildCipherStates(t, noise.CipherChaChaPoly)
|
||||
roundtrip(t, NewCipherStateChaChaPoly(enc), NewCipherStateChaChaPoly(dec))
|
||||
}
|
||||
|
||||
func TestNewCipherStateDispatch(t *testing.T) {
|
||||
encA, _ := buildCipherStates(t, CipherAESGCM)
|
||||
encC, _ := buildCipherStates(t, noise.CipherChaChaPoly)
|
||||
|
||||
assert.IsType(t, &CipherStateAESGCM{}, NewCipherState(encA, CipherAESGCM))
|
||||
assert.IsType(t, &CipherStateChaChaPoly{}, NewCipherState(encC, noise.CipherChaChaPoly))
|
||||
}
|
||||
|
||||
func TestNewCipherStateUnsupportedPanics(t *testing.T) {
|
||||
enc, _ := buildCipherStates(t, CipherAESGCM)
|
||||
assert.Panics(t, func() {
|
||||
NewCipherState(enc, fakeCipher{})
|
||||
})
|
||||
}
|
||||
|
||||
type fakeCipher struct{}
|
||||
|
||||
func (fakeCipher) Cipher(k [32]byte) noise.Cipher { return nil }
|
||||
func (fakeCipher) CipherName() string { return "Fake" }
|
||||
|
||||
// buildCipherStates runs an in-memory NN handshake with the requested cipher
|
||||
// to produce a pair of post-handshake CipherStates that share keys.
|
||||
func buildCipherStates(t *testing.T, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
|
||||
t.Helper()
|
||||
suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256)
|
||||
cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN}
|
||||
cfg.Initiator = true
|
||||
hsI, err := noise.NewHandshakeState(cfg)
|
||||
require.NoError(t, err)
|
||||
cfg.Initiator = false
|
||||
hsR, err := noise.NewHandshakeState(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg, _, _, err := hsI.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, _, _, err = hsR.ReadMessage(nil, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg, dR, _, err := hsR.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, eI, _, err := hsI.ReadMessage(nil, msg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, eI)
|
||||
require.NotNil(t, dR)
|
||||
|
||||
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
|
||||
return eI, dR
|
||||
}
|
||||
|
||||
func roundtrip(t *testing.T, enc, dec CipherState) {
|
||||
t.Helper()
|
||||
plaintext := []byte("nebula cipher state roundtrip")
|
||||
ad := []byte("aad")
|
||||
nb := make([]byte, 12)
|
||||
|
||||
ct, err := enc.EncryptDanger(nil, ad, plaintext, 1, nb)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, plaintext, ct)
|
||||
|
||||
pt, err := dec.DecryptDanger(nil, ad, ct, 1, nb)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, pt)
|
||||
|
||||
// Wrong nonce must fail authentication.
|
||||
_, err = dec.DecryptDanger(nil, ad, ct, 2, nb)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, enc.Overhead(), dec.Overhead())
|
||||
assert.Equal(t, 16, enc.Overhead())
|
||||
}
|
||||
|
||||
func BenchmarkCipherStateEncryptAESGCM(b *testing.B) {
|
||||
enc, _ := buildCipherStatesB(b, CipherAESGCM)
|
||||
benchEncryptCipherState(b, NewCipherState(enc, CipherAESGCM))
|
||||
}
|
||||
|
||||
func BenchmarkCipherStateEncryptChaChaPoly(b *testing.B) {
|
||||
enc, _ := buildCipherStatesB(b, noise.CipherChaChaPoly)
|
||||
benchEncryptCipherState(b, NewCipherState(enc, noise.CipherChaChaPoly))
|
||||
}
|
||||
|
||||
func benchEncryptCipherState(b *testing.B, cs CipherState) {
|
||||
plaintext := make([]byte, 1280)
|
||||
ad := make([]byte, 16)
|
||||
nb := make([]byte, 12)
|
||||
out := make([]byte, 0, len(plaintext)+cs.Overhead())
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var err error
|
||||
out, err = cs.EncryptDanger(out[:0], ad, plaintext, uint64(i+1), nb)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildCipherStatesB(b *testing.B, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) {
|
||||
b.Helper()
|
||||
suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256)
|
||||
cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN}
|
||||
cfg.Initiator = true
|
||||
hsI, err := noise.NewHandshakeState(cfg)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
cfg.Initiator = false
|
||||
hsR, err := noise.NewHandshakeState(cfg)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
msg, _, _, err := hsI.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, _, _, err := hsR.ReadMessage(nil, msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
msg, dR, _, err := hsR.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_, eI, _, err := hsI.ReadMessage(nil, msg)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
return eI, dR
|
||||
}
|
||||
|
||||
func TestCipherStateNilSafety(t *testing.T) {
|
||||
var aes *CipherStateAESGCM
|
||||
_, err := aes.EncryptDanger(nil, nil, nil, 0, make([]byte, 12))
|
||||
require.Error(t, err)
|
||||
out, err := aes.DecryptDanger(nil, nil, nil, 0, make([]byte, 12))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, out)
|
||||
assert.Equal(t, 0, aes.Overhead())
|
||||
|
||||
var cc *CipherStateChaChaPoly
|
||||
_, err = cc.EncryptDanger(nil, nil, nil, 0, make([]byte, 12))
|
||||
require.Error(t, err)
|
||||
out, err = cc.DecryptDanger(nil, nil, nil, 0, make([]byte, 12))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, out)
|
||||
assert.Equal(t, 0, cc.Overhead())
|
||||
}
|
||||
+2
-3
@@ -194,8 +194,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
|
||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||
// its internal mapping. This should never happen.
|
||||
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"relayRemoteIndex", h.RemoteIndex,
|
||||
)
|
||||
return
|
||||
}
|
||||
@@ -218,8 +217,8 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
|
||||
"relayTo", relay.PeerAddr,
|
||||
"relayFrom", hostinfo.vpnAddrs[0],
|
||||
"error", err,
|
||||
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,358 @@
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h.
type networkCategory int32

// Values are spelled out explicitly because they must match the Windows header.
const (
	networkCategoryPublic              networkCategory = 0
	networkCategoryPrivate             networkCategory = 1
	networkCategoryDomainAuthenticated networkCategory = 2
)

// String returns the lowercase name of the category, or "unknown(N)" for any
// value outside the three known categories.
func (c networkCategory) String() string {
	if c == networkCategoryPublic {
		return "public"
	}
	if c == networkCategoryPrivate {
		return "private"
	}
	if c == networkCategoryDomainAuthenticated {
		return "domain"
	}
	return fmt.Sprintf("unknown(%d)", c)
}
|
||||
|
||||
// parseNetworkCategory accepts the user-supplied tun.network_category. A
|
||||
// second return of false means "leave the category alone".
|
||||
func parseNetworkCategory(s string) (networkCategory, bool, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "", "unset":
|
||||
return 0, false, nil
|
||||
case "public":
|
||||
return networkCategoryPublic, true, nil
|
||||
case "private":
|
||||
return networkCategoryPrivate, true, nil
|
||||
case "domain", "domainauthenticated":
|
||||
return networkCategoryDomainAuthenticated, true, nil
|
||||
}
|
||||
return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s)
|
||||
}
|
||||
|
||||
// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B}
|
||||
var clsidNetworkListManager = windows.GUID{
|
||||
Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B,
|
||||
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
|
||||
}
|
||||
|
||||
// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B}
|
||||
var iidINetworkListManager = windows.GUID{
|
||||
Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B,
|
||||
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
|
||||
}
|
||||
|
||||
// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves.
|
||||
var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance")
|
||||
|
||||
const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER |
|
||||
windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER
|
||||
|
||||
const (
|
||||
hrSFALSE = 0x00000001
|
||||
hrRPCEChangedMode = 0x80010106
|
||||
)
|
||||
|
||||
// hresult holds the 32-bit status code returned by the COM calls in this file.
type hresult uint32

// failed reports whether the code has its top bit set, i.e. is negative when
// read as a signed 32-bit value.
func (h hresult) failed() bool {
	return h&0x80000000 != 0
}

// String formats the code as "HRESULT 0x" followed by eight hex digits.
func (h hresult) String() string {
	raw := uint32(h)
	return fmt.Sprintf("HRESULT 0x%08x", raw)
}
|
||||
|
||||
// errAdapterNotFound is returned by setNetworkCategory when no enumerated
// network connection belongs to the requested adapter.
var errAdapterNotFound = errors.New("adapter not present in network connections enumeration")
|
||||
|
||||
// Vtable layouts. Slot order must match the declaration order in netlistmgr.h.
// All NLM interfaces here derive from IDispatch, which derives from IUnknown.

// iUnknownVtbl holds the three IUnknown slots, in order.
type iUnknownVtbl struct {
	QueryInterface, AddRef, Release uintptr
}
|
||||
|
||||
type iDispatchVtbl struct {
|
||||
iUnknownVtbl
|
||||
GetTypeInfoCount uintptr
|
||||
GetTypeInfo uintptr
|
||||
GetIDsOfNames uintptr
|
||||
Invoke uintptr
|
||||
}
|
||||
|
||||
type iNetworkListManagerVtbl struct {
|
||||
iDispatchVtbl
|
||||
GetNetworks uintptr
|
||||
GetNetwork uintptr
|
||||
GetNetworkConnections uintptr
|
||||
GetNetworkConnection uintptr
|
||||
IsConnectedToInternet uintptr
|
||||
IsConnected uintptr
|
||||
GetConnectivity uintptr
|
||||
}
|
||||
|
||||
type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl }
|
||||
|
||||
func (n *iNetworkListManager) Release() {
|
||||
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
|
||||
}
|
||||
|
||||
func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) {
|
||||
var enum *iEnumNetworkConnections
|
||||
r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections,
|
||||
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr)
|
||||
}
|
||||
return enum, nil
|
||||
}
|
||||
|
||||
type iEnumNetworkConnectionsVtbl struct {
|
||||
iDispatchVtbl
|
||||
NewEnum uintptr
|
||||
Next uintptr
|
||||
Skip uintptr
|
||||
Reset uintptr
|
||||
Clone uintptr
|
||||
}
|
||||
|
||||
type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl }
|
||||
|
||||
func (e *iEnumNetworkConnections) Release() {
|
||||
syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e)))
|
||||
}
|
||||
|
||||
// Next returns the next connection, or (nil, nil) at the end of the enumeration.
|
||||
func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) {
|
||||
var conn *iNetworkConnection
|
||||
var fetched uint32
|
||||
r1, _, _ := syscall.SyscallN(e.Vtbl.Next,
|
||||
uintptr(unsafe.Pointer(e)), 1,
|
||||
uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr)
|
||||
}
|
||||
if fetched == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type iNetworkConnectionVtbl struct {
|
||||
iDispatchVtbl
|
||||
GetNetwork uintptr
|
||||
IsConnectedToInternet uintptr
|
||||
IsConnected uintptr
|
||||
GetConnectivity uintptr
|
||||
GetConnectionId uintptr
|
||||
GetAdapterId uintptr
|
||||
GetDomainType uintptr
|
||||
}
|
||||
|
||||
type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl }
|
||||
|
||||
func (c *iNetworkConnection) Release() {
|
||||
syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c)))
|
||||
}
|
||||
|
||||
func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) {
|
||||
var g windows.GUID
|
||||
r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId,
|
||||
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) {
|
||||
var net *iNetwork
|
||||
r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork,
|
||||
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr)
|
||||
}
|
||||
return net, nil
|
||||
}
|
||||
|
||||
type iNetworkVtbl struct {
|
||||
iDispatchVtbl
|
||||
GetName uintptr
|
||||
SetName uintptr
|
||||
GetDescription uintptr
|
||||
SetDescription uintptr
|
||||
GetNetworkId uintptr
|
||||
GetDomainType uintptr
|
||||
GetNetworkConnections uintptr
|
||||
GetTimeCreatedAndConnected uintptr
|
||||
IsConnectedToInternet uintptr
|
||||
IsConnected uintptr
|
||||
GetConnectivity uintptr
|
||||
GetCategory uintptr
|
||||
SetCategory uintptr
|
||||
}
|
||||
|
||||
type iNetwork struct{ Vtbl *iNetworkVtbl }
|
||||
|
||||
func (n *iNetwork) Release() {
|
||||
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
|
||||
}
|
||||
|
||||
func (n *iNetwork) GetCategory() (networkCategory, error) {
|
||||
var c networkCategory
|
||||
r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory,
|
||||
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return 0, fmt.Errorf("INetwork.GetCategory: %s", hr)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (n *iNetwork) SetCategory(c networkCategory) error {
|
||||
r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory,
|
||||
uintptr(unsafe.Pointer(n)), uintptr(int32(c)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return fmt.Errorf("INetwork.SetCategory: %s", hr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// coInit initializes COM for the current OS thread. The returned function must
|
||||
// be deferred to balance a successful init. RPC_E_CHANGED_MODE means COM is
|
||||
// already initialized in a different mode on this thread, which is still fine
|
||||
// for our calls but we must not Uninitialize in that case.
|
||||
func coInit() (func(), error) {
|
||||
err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
|
||||
if err == nil {
|
||||
return windows.CoUninitialize, nil
|
||||
}
|
||||
if e, ok := err.(syscall.Errno); ok {
|
||||
switch uint32(e) {
|
||||
case hrSFALSE:
|
||||
return windows.CoUninitialize, nil
|
||||
case hrRPCEChangedMode:
|
||||
return func() {}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("CoInitializeEx: %w", err)
|
||||
}
|
||||
|
||||
func createNetworkListManager() (*iNetworkListManager, error) {
|
||||
var nlm *iNetworkListManager
|
||||
r1, _, _ := procCoCreateInstance.Call(
|
||||
uintptr(unsafe.Pointer(&clsidNetworkListManager)),
|
||||
0,
|
||||
uintptr(clsCtxAll),
|
||||
uintptr(unsafe.Pointer(&iidINetworkListManager)),
|
||||
uintptr(unsafe.Pointer(&nlm)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr)
|
||||
}
|
||||
return nlm, nil
|
||||
}
|
||||
|
||||
// setNetworkCategory locates the network connection bound to adapterGUID and
|
||||
// sets the category of its parent network. Returns errAdapterNotFound if the
|
||||
// adapter is not yet visible in the NLM enumeration.
|
||||
func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error {
|
||||
deinit, err := coInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer deinit()
|
||||
|
||||
nlm, err := createNetworkListManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer nlm.Release()
|
||||
|
||||
enum, err := nlm.GetNetworkConnections()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer enum.Release()
|
||||
|
||||
for {
|
||||
conn, err := enum.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if conn == nil {
|
||||
return errAdapterNotFound
|
||||
}
|
||||
|
||||
guid, err := conn.GetAdapterId()
|
||||
if err != nil || guid != adapterGUID {
|
||||
conn.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
net, err := conn.GetNetwork()
|
||||
conn.Release()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = net.SetCategory(cat)
|
||||
net.Release()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// applyNetworkCategory polls until the wintun adapter shows up in the NLM
|
||||
// enumeration, then sets the category. Intended to run in its own goroutine.
|
||||
func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) {
|
||||
// COM Init/Uninit must be paired on the same OS thread.
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
const (
|
||||
attempts = 30
|
||||
interval = 500 * time.Millisecond
|
||||
)
|
||||
for i := 0; i < attempts; i++ {
|
||||
err := setNetworkCategory(adapterGUID, cat)
|
||||
if err == nil {
|
||||
l.Info("Set Windows network category", "category", cat.String())
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errAdapterNotFound) {
|
||||
l.Warn("Failed to set Windows network category", "error", err, "category", cat.String())
|
||||
return
|
||||
}
|
||||
time.Sleep(interval)
|
||||
}
|
||||
l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set",
|
||||
"category", cat.String(),
|
||||
"waited", time.Duration(attempts)*interval,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
	"runtime"
	"testing"
)
|
||||
|
||||
func Test_parseNetworkCategory(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
wantCat networkCategory
|
||||
wantApply bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"", 0, false, false},
|
||||
{"unset", 0, false, false},
|
||||
{" UNSET ", 0, false, false},
|
||||
{"private", networkCategoryPrivate, true, false},
|
||||
{"Private", networkCategoryPrivate, true, false},
|
||||
{" PRIVATE ", networkCategoryPrivate, true, false},
|
||||
{"public", networkCategoryPublic, true, false},
|
||||
{"PUBLIC", networkCategoryPublic, true, false},
|
||||
{"domain", networkCategoryDomainAuthenticated, true, false},
|
||||
{"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false},
|
||||
{"garbage", 0, false, true},
|
||||
{"privates", 0, false, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
cat, apply, err := parseNetworkCategory(tc.in)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr)
|
||||
continue
|
||||
}
|
||||
if cat != tc.wantCat || apply != tc.wantApply {
|
||||
t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", tc.in, cat, apply, tc.wantCat, tc.wantApply)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory
|
||||
// without mutating the host's network state. It validates the CLSID/IID
|
||||
// constants and every vtable index by enumerating connections, fetching the
|
||||
// adapter id and parent network, reading the current category, and writing it
|
||||
// back unchanged.
|
||||
//
|
||||
// Requires Windows but does not require admin or the wintun driver. Skips if
|
||||
// no network connections are available (unlikely outside of an isolated
|
||||
// container).
|
||||
func Test_NLM_round_trip(t *testing.T) {
|
||||
deinit, err := coInit()
|
||||
if err != nil {
|
||||
t.Fatalf("coInit: %v", err)
|
||||
}
|
||||
defer deinit()
|
||||
|
||||
nlm, err := createNetworkListManager()
|
||||
if err != nil {
|
||||
t.Fatalf("createNetworkListManager: %v", err)
|
||||
}
|
||||
defer nlm.Release()
|
||||
|
||||
enum, err := nlm.GetNetworkConnections()
|
||||
if err != nil {
|
||||
t.Fatalf("GetNetworkConnections: %v", err)
|
||||
}
|
||||
defer enum.Release()
|
||||
|
||||
saw := 0
|
||||
for {
|
||||
conn, err := enum.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("EnumNetworkConnections.Next: %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
break
|
||||
}
|
||||
saw++
|
||||
|
||||
if _, err := conn.GetAdapterId(); err != nil {
|
||||
conn.Release()
|
||||
t.Fatalf("INetworkConnection.GetAdapterId: %v", err)
|
||||
}
|
||||
|
||||
net, err := conn.GetNetwork()
|
||||
conn.Release()
|
||||
if err != nil {
|
||||
t.Fatalf("INetworkConnection.GetNetwork: %v", err)
|
||||
}
|
||||
|
||||
cat, err := net.GetCategory()
|
||||
if err != nil {
|
||||
net.Release()
|
||||
t.Fatalf("INetwork.GetCategory: %v", err)
|
||||
}
|
||||
// Set to the current value so the host's NLM state is unchanged but
|
||||
// SetCategory's vtable slot is still validated end-to-end.
|
||||
if err := net.SetCategory(cat); err != nil {
|
||||
net.Release()
|
||||
t.Fatalf("INetwork.SetCategory(%v): %v", cat, err)
|
||||
}
|
||||
net.Release()
|
||||
}
|
||||
|
||||
if saw == 0 {
|
||||
t.Skip("no NLM network connections available; skipping round-trip")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
//go:build (amd64 || arm64) && !e2e_testing
|
||||
// +build amd64 arm64
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/slackhq/nebula/wfp"
|
||||
)
|
||||
|
||||
// installInterfaceBypass installs a WFP PERMIT filter scoped to the wintun interface LUID so inbound traffic on the
|
||||
// nebula adapter bypasses Windows Defender Firewall.
|
||||
func installInterfaceBypass(l *slog.Logger, luid uint64) closer {
|
||||
s, err := wfp.PermitInterface(luid)
|
||||
if err != nil {
|
||||
l.Warn("Failed to install WFP bypass filters on nebula interface", "error", err)
|
||||
return nil
|
||||
}
|
||||
l.Info("Installed WFP filters bypassing Windows Defender Firewall on nebula interface")
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import "log/slog"
|
||||
|
||||
// installInterfaceBypass is a no-op on windows-386 because we don't currently build for it.
|
||||
func installInterfaceBypass(_ *slog.Logger, _ uint64) closer {
|
||||
return nil
|
||||
}
|
||||
+54
-16
@@ -25,15 +25,24 @@ import (
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
// closer is anything that can be shut down with Close and reports no error.
type closer interface {
	Close()
}
|
||||
|
||||
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
||||
|
||||
type winTun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *slog.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
guid windows.GUID
|
||||
networkCategory networkCategory
|
||||
setCategory bool
|
||||
bypassWDF bool
|
||||
wdfBypass closer
|
||||
l *slog.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
@@ -54,11 +63,20 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
|
||||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||
}
|
||||
|
||||
cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &winTun{
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
guid: *guid,
|
||||
networkCategory: cat,
|
||||
setCategory: setCat,
|
||||
bypassWDF: c.GetBool("tun.windows_bypass_wdf", true),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -142,6 +160,17 @@ func (t *winTun) Activate() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.setCategory {
|
||||
// The wintun adapter takes a moment to register with the Network List
|
||||
// Manager, so we apply the category in the background and retry until
|
||||
// it shows up.
|
||||
go applyNetworkCategory(t.l, t.guid, t.networkCategory)
|
||||
}
|
||||
|
||||
if t.bypassWDF {
|
||||
t.wdfBypass = installInterfaceBypass(t.l, uint64(t.tun.LUID()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -156,11 +185,8 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add our unsafe route
|
||||
// Windows does not support multipath routes natively, so we install only a single route.
|
||||
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
|
||||
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
|
||||
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
||||
// Add our unsafe route as an on-link route to the nebula tun device.
|
||||
err := luid.AddRoute(r.Cidr, unspecifiedNextHop(r.Cidr), uint32(r.Metric))
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
@@ -206,7 +232,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
||||
}
|
||||
|
||||
// See comment on luid.AddRoute
|
||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||
err := luid.DeleteRoute(r.Cidr, unspecifiedNextHop(r.Cidr))
|
||||
if err != nil {
|
||||
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||
} else {
|
||||
@@ -258,9 +284,21 @@ func (t *winTun) Close() error {
|
||||
_ = luid.FlushDNS(windows.AF_INET)
|
||||
_ = luid.FlushDNS(windows.AF_INET6)
|
||||
|
||||
if t.wdfBypass != nil {
|
||||
t.wdfBypass.Close()
|
||||
t.wdfBypass = nil
|
||||
}
|
||||
|
||||
return t.tun.Close()
|
||||
}
|
||||
|
||||
// unspecifiedNextHop returns the unspecified address of the same family as p:
// 0.0.0.0 when p's address is IPv4, :: otherwise.
func unspecifiedNextHop(p netip.Prefix) netip.Addr {
	if !p.Addr().Is4() {
		return netip.IPv6Unspecified()
	}
	return netip.IPv4Unspecified()
}
|
||||
|
||||
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
||||
// GUID is 128 bit
|
||||
hash := crypto.MD5.New()
|
||||
|
||||
@@ -99,12 +99,10 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
var currentState *CertState
|
||||
if initial {
|
||||
cipher = c.GetString("cipher", "aes")
|
||||
//TODO: this sucks and we should make it not a global
|
||||
switch cipher {
|
||||
case "aes":
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
case "aes", "chachapoly":
|
||||
// Each post-handshake CipherState in noiseutil hardcodes its own
|
||||
// nonce endianness now, so there's nothing to set up here.
|
||||
default:
|
||||
return util.NewContextualError(
|
||||
"unknown cipher",
|
||||
|
||||
@@ -1,24 +1,70 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// holepunchQueueSize buffers the channel that pending holepunchJobs land on after their delay timer fires.
|
||||
const holepunchQueueSize = 64
|
||||
|
||||
// holepunchJob is one scheduled item delivered to the worker goroutine.
|
||||
// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context.
|
||||
// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback").
|
||||
type holepunchJob struct {
|
||||
// target is the UDP address to punch; left as the zero value for punchback jobs.
target netip.AddrPort
|
||||
// vpnAddr is the peer's vpn addr: log context for punch jobs, the destination for punchback jobs.
vpnAddr netip.Addr
|
||||
}
|
||||
|
||||
// lighthouseChecker is the slice of LightHouse that Punchy actually needs.
|
||||
// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse
|
||||
// already holds a *Punchy, and the bidirectional pointer reference is awkward
|
||||
// even within the same package). Tests can also substitute a fake.
|
||||
type lighthouseChecker interface {
|
||||
// IsAnyLighthouseAddr is consulted with a host's vpn addrs before keepalive punches are sent;
// a true result suppresses the punch. Presumably it reports whether any of vpnAddrs belongs
// to a lighthouse -- confirm against the LightHouse implementation.
IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool
|
||||
}
|
||||
|
||||
type Punchy struct {
|
||||
punch atomic.Bool
|
||||
respond atomic.Bool
|
||||
delay atomic.Int64
|
||||
respondDelay atomic.Int64
|
||||
punchEverything atomic.Bool
|
||||
l *slog.Logger
|
||||
|
||||
sched *Scheduler[holepunchJob]
|
||||
punchConn udp.Conn
|
||||
metricHolepunchTx metrics.Counter
|
||||
metricPunchyTx metrics.Counter
|
||||
|
||||
ctx context.Context
|
||||
ifce EncWriter
|
||||
hm *HostMap
|
||||
lh lighthouseChecker
|
||||
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
|
||||
p := &Punchy{l: l}
|
||||
func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy {
|
||||
p := &Punchy{
|
||||
l: l,
|
||||
punchConn: punchConn,
|
||||
sched: NewScheduler[holepunchJob](holepunchQueueSize),
|
||||
metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
||||
}
|
||||
|
||||
if c.GetBool("stats.lighthouse_metrics", false) {
|
||||
p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
|
||||
} else {
|
||||
p.metricHolepunchTx = metrics.NilCounter{}
|
||||
}
|
||||
|
||||
p.reload(c, true)
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
@@ -29,7 +75,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
|
||||
}
|
||||
|
||||
func (p *Punchy) reload(c *config.C, initial bool) {
|
||||
if initial {
|
||||
if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
|
||||
var yes bool
|
||||
if c.IsSet("punchy.punch") {
|
||||
yes = c.GetBool("punchy.punch", false)
|
||||
@@ -38,16 +84,15 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
||||
yes = c.GetBool("punchy", false)
|
||||
}
|
||||
|
||||
p.punch.Store(yes)
|
||||
if yes {
|
||||
old := p.punch.Swap(yes)
|
||||
switch {
|
||||
case initial && yes:
|
||||
p.l.Info("punchy enabled")
|
||||
} else {
|
||||
case initial:
|
||||
p.l.Info("punchy disabled")
|
||||
case old != yes:
|
||||
p.l.Info("punchy.punch changed", "punch", yes)
|
||||
}
|
||||
|
||||
} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
|
||||
//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
|
||||
p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") {
|
||||
@@ -59,52 +104,132 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
||||
yes = c.GetBool("punch_back", false)
|
||||
}
|
||||
|
||||
p.respond.Store(yes)
|
||||
|
||||
if !initial {
|
||||
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
|
||||
old := p.respond.Swap(yes)
|
||||
if !initial && old != yes {
|
||||
p.l.Info("punchy.respond changed", "respond", yes)
|
||||
}
|
||||
}
|
||||
|
||||
//NOTE: this will not apply to any in progress operations, only the next one
|
||||
if initial || c.HasChanged("punchy.delay") {
|
||||
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
|
||||
if !initial {
|
||||
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
|
||||
newDelay := int64(c.GetDuration("punchy.delay", time.Second))
|
||||
old := p.delay.Swap(newDelay)
|
||||
if !initial && old != newDelay {
|
||||
p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay))
|
||||
}
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("punchy.target_all_remotes") {
|
||||
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
|
||||
if !initial {
|
||||
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
|
||||
yes := c.GetBool("punchy.target_all_remotes", false)
|
||||
old := p.punchEverything.Swap(yes)
|
||||
if !initial && old != yes {
|
||||
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes)
|
||||
}
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("punchy.respond_delay") {
|
||||
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
|
||||
if !initial {
|
||||
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
|
||||
newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second))
|
||||
old := p.respondDelay.Swap(newDelay)
|
||||
if !initial && old != newDelay {
|
||||
p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Punchy) GetPunch() bool {
|
||||
return p.punch.Load()
|
||||
// Schedule queues a punch packet to target, to be sent after the configured delay.
|
||||
// vpnAddr is the peer's vpn addr, used for log context when the packet actually fires.
|
||||
// No-op if target is not a valid AddrPort or if Start has not yet been called. Safe to call from any goroutine.
|
||||
func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) {
|
||||
if !target.IsValid() || p.ctx == nil {
|
||||
return
|
||||
}
|
||||
p.scheduleJob(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load()))
|
||||
}
|
||||
|
||||
func (p *Punchy) GetRespond() bool {
|
||||
return p.respond.Load()
|
||||
// ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay,
|
||||
// gated on punchy.respond. No-op when respond is disabled or before Start has been called.
|
||||
func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) {
|
||||
if !p.respond.Load() || p.ctx == nil {
|
||||
return
|
||||
}
|
||||
p.scheduleJob(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load()))
|
||||
}
|
||||
|
||||
func (p *Punchy) GetDelay() time.Duration {
|
||||
return (time.Duration)(p.delay.Load())
|
||||
// scheduleJob delegates to the pooled Scheduler.
|
||||
// The callback observes p.ctx so a job that becomes due after Stop is dropped instead of queued.
|
||||
func (p *Punchy) scheduleJob(job holepunchJob, delay time.Duration) {
|
||||
p.sched.Schedule(p.ctx, job, delay)
|
||||
}
|
||||
|
||||
func (p *Punchy) GetRespondDelay() time.Duration {
|
||||
return (time.Duration)(p.respondDelay.Load())
|
||||
// SendPunch sends an immediate keepalive punch for an idle hostinfo.
|
||||
// The configured punchy.target_all_remotes mode picks the targets. Gated on punchy.punch and the lighthouse-skip rule
|
||||
// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm).
|
||||
func (p *Punchy) SendPunch(hostinfo *HostInfo) {
|
||||
if !p.punch.Load() {
|
||||
return
|
||||
}
|
||||
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
||||
return
|
||||
}
|
||||
|
||||
if p.punchEverything.Load() {
|
||||
p.sendPunchToAllRemotes(hostinfo)
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
p.metricPunchyTx.Inc(1)
|
||||
p.punchConn.WriteTo([]byte{1}, hostinfo.remote)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Punchy) GetTargetEverything() bool {
|
||||
return p.punchEverything.Load()
|
||||
// SendPunchToAll punches every known remote for hostinfo, but only when punchy.target_all_remotes is enabled.
|
||||
// The connection manager calls this during outbound-only traffic: the outbound traffic itself keeps the primary's
|
||||
// NAT state warm, but non-primary remotes need separate refresh, so we fan out to all of them (the redundant
|
||||
// primary punch is harmless). Gated on punchy.punch and the lighthouse-skip rule.
|
||||
func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) {
|
||||
if !p.punchEverything.Load() {
|
||||
return
|
||||
}
|
||||
if !p.punch.Load() {
|
||||
return
|
||||
}
|
||||
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
||||
return
|
||||
}
|
||||
p.sendPunchToAllRemotes(hostinfo)
|
||||
}
|
||||
|
||||
func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) {
|
||||
hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||
p.metricPunchyTx.Inc(1)
|
||||
p.punchConn.WriteTo([]byte{1}, addr)
|
||||
})
|
||||
}
|
||||
|
||||
// Start wires the runtime dependencies and spawns the scheduler worker.
|
||||
func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) {
|
||||
p.ctx = ctx
|
||||
p.ifce = ifce
|
||||
p.hm = hm
|
||||
p.lh = lh
|
||||
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
empty := []byte{0}
|
||||
|
||||
go p.sched.Run(ctx, func(job holepunchJob) {
|
||||
switch {
|
||||
case job.target.IsValid():
|
||||
if p.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr)
|
||||
}
|
||||
p.metricHolepunchTx.Inc(1)
|
||||
p.punchConn.WriteTo(empty, job.target)
|
||||
case job.vpnAddr.IsValid():
|
||||
// A nebula test packet to the host trying to contact us.
|
||||
// In the case of a double nat or other difficult scenario, this may help establish a tunnel.
|
||||
if p.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr)
|
||||
}
|
||||
p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+40
-41
@@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
||||
c := config.NewC(l)
|
||||
|
||||
// Test defaults
|
||||
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.False(t, p.GetPunch())
|
||||
assert.False(t, p.GetRespond())
|
||||
assert.Equal(t, time.Second, p.GetDelay())
|
||||
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
|
||||
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.False(t, p.punch.Load())
|
||||
assert.False(t, p.respond.Load())
|
||||
assert.Equal(t, time.Second, time.Duration(p.delay.Load()))
|
||||
assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load()))
|
||||
|
||||
// punchy deprecation
|
||||
c.Settings["punchy"] = true
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetPunch())
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.True(t, p.punch.Load())
|
||||
|
||||
// punchy.punch
|
||||
c.Settings["punchy"] = map[string]any{"punch": true}
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetPunch())
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.True(t, p.punch.Load())
|
||||
|
||||
// punch_back deprecation
|
||||
c.Settings["punch_back"] = true
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetRespond())
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.True(t, p.respond.Load())
|
||||
|
||||
// punchy.respond
|
||||
c.Settings["punchy"] = map[string]any{"respond": true}
|
||||
c.Settings["punch_back"] = false
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetRespond())
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.True(t, p.respond.Load())
|
||||
|
||||
// punchy.delay
|
||||
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.Equal(t, time.Minute, p.GetDelay())
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.Equal(t, time.Minute, time.Duration(p.delay.Load()))
|
||||
|
||||
// punchy.respond_delay
|
||||
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load()))
|
||||
}
|
||||
|
||||
func TestPunchy_reload(t *testing.T) {
|
||||
@@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) {
|
||||
delay, _ := time.ParseDuration("1m")
|
||||
require.NoError(t, c.LoadString(`
|
||||
punchy:
|
||||
punch: false
|
||||
delay: 1m
|
||||
respond: false
|
||||
`))
|
||||
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.Equal(t, delay, p.GetDelay())
|
||||
assert.False(t, p.GetRespond())
|
||||
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||
assert.False(t, p.punch.Load())
|
||||
assert.Equal(t, delay, time.Duration(p.delay.Load()))
|
||||
assert.False(t, p.respond.Load())
|
||||
|
||||
newDelay, _ := time.ParseDuration("10m")
|
||||
require.NoError(t, c.ReloadConfigString(`
|
||||
punchy:
|
||||
punch: true
|
||||
delay: 10m
|
||||
respond: true
|
||||
`))
|
||||
p.reload(c, false)
|
||||
assert.Equal(t, newDelay, p.GetDelay())
|
||||
assert.True(t, p.GetRespond())
|
||||
assert.True(t, p.punch.Load())
|
||||
assert.Equal(t, newDelay, time.Duration(p.delay.Load()))
|
||||
assert.True(t, p.respond.Load())
|
||||
}
|
||||
|
||||
// The tests below pin the shape of each log line Punchy produces so changes
|
||||
// cannot silently break whatever operators are grepping for. The assertions
|
||||
// are on the structured message + attrs (e.g. "punchy.respond changed" with
|
||||
// a respond=true field) rather than a formatted string.
|
||||
//
|
||||
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
|
||||
// not supported" warning whenever any key under punchy changes, because of
|
||||
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
|
||||
// punchy form. The tests filter by message rather than asserting total
|
||||
// entry counts so that warning is tolerated without being locked into
|
||||
// the format.
|
||||
// a respond=true field) rather than a formatted string. Tests filter by
|
||||
// message rather than asserting total entry counts so unrelated info lines
|
||||
// are tolerated without being locked into the format.
|
||||
|
||||
type capturedEntry struct {
|
||||
Level slog.Level
|
||||
@@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
|
||||
|
||||
NewPunchyFromConfig(l, c)
|
||||
NewPunchyFromConfig(l, c, nil)
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy enabled")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
@@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||
|
||||
NewPunchyFromConfig(l, c)
|
||||
NewPunchyFromConfig(l, c, nil)
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy disabled")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Empty(t, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
|
||||
func TestPunchy_LogFormat_ReloadPunch(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
NewPunchyFromConfig(l, c, nil)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
|
||||
|
||||
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
|
||||
assert.Equal(t, slog.LevelWarn, entry.Level)
|
||||
assert.Empty(t, entry.Attrs)
|
||||
entry := findEntry(t, hook.entries, "punchy.punch changed")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Equal(t, map[string]any{"punch": true}, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
NewPunchyFromConfig(l, c, nil)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
|
||||
@@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
NewPunchyFromConfig(l, c, nil)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
|
||||
@@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
NewPunchyFromConfig(l, c, nil)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
|
||||
@@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
NewPunchyFromConfig(l, c, nil)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
|
||||
|
||||
+37
-18
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
@@ -57,14 +58,25 @@ func (rm *relayManager) GetUseRelays() bool {
|
||||
// For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits
|
||||
// one that may have been lost, or, once the relay is Established, forwards the in-progress
|
||||
// stage 0 handshake packet for vpnIp through it.
|
||||
func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *HostInfo, stage0 []byte) {
|
||||
func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hh *HandshakeHostInfo, stage0 []byte) {
|
||||
hostinfo := hh.hostinfo
|
||||
if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 {
|
||||
hh.lastRelays = nil
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
|
||||
relays := hostinfo.remotes.relays
|
||||
listLevel := slog.LevelDebug
|
||||
prior := hh.lastRelays
|
||||
if !slices.Equal(relays, prior) {
|
||||
listLevel = slog.LevelInfo
|
||||
hh.lastRelays = slices.Clone(relays)
|
||||
}
|
||||
hl := hostinfo.logger(rm.l)
|
||||
hl.Log(context.Background(), listLevel, "Attempt to relay through hosts", "relays", relays)
|
||||
|
||||
// Send a RelayRequest to all known Relay IP's
|
||||
for _, relay := range hostinfo.remotes.relays {
|
||||
for _, relay := range relays {
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
if relay == vpnIp {
|
||||
continue
|
||||
@@ -75,12 +87,19 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
continue
|
||||
}
|
||||
|
||||
// Each relay's per-attempt log fires at Info on the first time we hit it and Debug after that.
|
||||
level := slog.LevelInfo
|
||||
if slices.Contains(prior, relay) {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
relayHostInfo := rm.hostmap.QueryVpnAddr(relay)
|
||||
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
||||
hostinfo.logger(rm.l).Info("Establish tunnel to relay target", "relay", relay.String())
|
||||
hl.Log(context.Background(), level, "Establish tunnel to relay target", "relay", relay.String())
|
||||
f.Handshake(relay)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check the relay HostInfo to see if we already established a relay through
|
||||
existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
|
||||
if !ok {
|
||||
@@ -88,7 +107,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
if relayHostInfo.remote.IsValid() {
|
||||
idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested)
|
||||
if err != nil {
|
||||
hostinfo.logger(rm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
|
||||
hl.Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
|
||||
}
|
||||
|
||||
m := NebulaControl{
|
||||
@@ -99,12 +118,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !f.myVpnAddrs[0].Is4() {
|
||||
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !vpnIp.Is4() {
|
||||
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -116,16 +135,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0])
|
||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||
default:
|
||||
hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay")
|
||||
hl.Error("Unknown certificate version found while creating relay")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||
hl.Error("Failed to marshal Control message to create relay", "error", err)
|
||||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.Info("send CreateRelayRequest",
|
||||
rm.l.Log(context.Background(), level, "send CreateRelayRequest",
|
||||
"relayFrom", f.myVpnAddrs[0],
|
||||
"relayTo", vpnIp,
|
||||
"initiatorRelayIndex", idx,
|
||||
@@ -138,14 +157,14 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
|
||||
switch existingRelay.State {
|
||||
case Established:
|
||||
hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String())
|
||||
hl.Log(context.Background(), level, "Send handshake via relay", "relay", relay.String())
|
||||
f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false)
|
||||
case Disestablished:
|
||||
// Mark this relay as 'requested'
|
||||
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
||||
fallthrough
|
||||
case Requested:
|
||||
hostinfo.logger(rm.l).Info("Re-send CreateRelay request", "relay", relay.String())
|
||||
hl.Log(context.Background(), level, "Re-send CreateRelay request", "relay", relay.String())
|
||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
||||
m := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
@@ -155,12 +174,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !f.myVpnAddrs[0].Is4() {
|
||||
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !vpnIp.Is4() {
|
||||
hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -172,16 +191,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0])
|
||||
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||
default:
|
||||
hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay")
|
||||
hl.Error("Unknown certificate version found while creating relay")
|
||||
continue
|
||||
}
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||
hl.Error("Failed to marshal Control message to create relay", "error", err)
|
||||
} else {
|
||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.Info("send CreateRelayRequest",
|
||||
rm.l.Log(context.Background(), level, "send CreateRelayRequest",
|
||||
"relayFrom", f.myVpnAddrs[0],
|
||||
"relayTo", vpnIp,
|
||||
"initiatorRelayIndex", existingRelay.LocalIndex,
|
||||
@@ -192,7 +211,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
|
||||
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
|
||||
fallthrough
|
||||
default:
|
||||
hostinfo.logger(rm.l).Error("Relay unexpected state",
|
||||
hl.Error("Relay unexpected state",
|
||||
"vpnIp", vpnIp,
|
||||
"state", existingRelay.State,
|
||||
"relay", relay,
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestStartRelaysLogDedupe verifies that repeated attempts with the same relay set drop the log
|
||||
// chatter to Debug, mirroring how the normal handshake retry loop quiets down once it's already
|
||||
// announced its targets.
|
||||
func TestStartRelaysLogDedupe(t *testing.T) {
|
||||
vpnIp := netip.MustParseAddr("100.64.99.4")
|
||||
otherRelay := netip.MustParseAddr("100.64.99.5")
|
||||
|
||||
newHH := func() *HandshakeHostInfo {
|
||||
// Use the target's own vpnIp as the "relay" so the loop body skips it without
|
||||
// touching any sender-side state. That isolates the test to the level-selection
|
||||
// behavior of the top-level "Attempt to relay through hosts" log.
|
||||
hostinfo := &HostInfo{
|
||||
vpnAddrs: []netip.Addr{vpnIp},
|
||||
localIndexId: 1,
|
||||
remotes: NewRemoteList([]netip.Addr{vpnIp}, nil),
|
||||
}
|
||||
hostinfo.remotes.relays = []netip.Addr{vpnIp}
|
||||
return &HandshakeHostInfo{hostinfo: hostinfo}
|
||||
}
|
||||
|
||||
// Park any extra relay addresses we'll introduce mid-test in myVpnAddrsTable so the loop
|
||||
// body always skips before touching f.Handshake (which would need a real handshakeManager).
|
||||
addrTable := new(bart.Lite)
|
||||
addrTable.Insert(netip.PrefixFrom(otherRelay, otherRelay.BitLen()))
|
||||
f := &Interface{myVpnAddrsTable: addrTable}
|
||||
|
||||
newRM := func(buf *bytes.Buffer) *relayManager {
|
||||
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||
rm := &relayManager{l: l, hostmap: newHostMap(l)}
|
||||
rm.useRelays.Store(true)
|
||||
return rm
|
||||
}
|
||||
|
||||
const msg = `msg="Attempt to relay through hosts"`
|
||||
|
||||
t.Run("first attempt logs at Info", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rm := newRM(&buf)
|
||||
hh := newHH()
|
||||
rm.StartRelays(f, vpnIp, hh, nil)
|
||||
assert.Equal(t, []netip.Addr{vpnIp}, hh.lastRelays, "lastRelays should record the relay set we just attempted")
|
||||
assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info level on first attempt")
|
||||
})
|
||||
|
||||
t.Run("repeat attempt with same relays drops to Debug", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rm := newRM(&buf)
|
||||
hh := newHH()
|
||||
rm.StartRelays(f, vpnIp, hh, nil)
|
||||
first := append([]netip.Addr(nil), hh.lastRelays...)
|
||||
buf.Reset()
|
||||
rm.StartRelays(f, vpnIp, hh, nil)
|
||||
assert.Equal(t, first, hh.lastRelays)
|
||||
assert.Contains(t, buf.String(), "level=DEBUG "+msg, "expected Debug level on identical retry")
|
||||
assert.NotContains(t, buf.String(), "level=INFO "+msg, "Info should not fire on identical retry")
|
||||
})
|
||||
|
||||
t.Run("changed relay list bumps back to Info", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rm := newRM(&buf)
|
||||
hh := newHH()
|
||||
rm.StartRelays(f, vpnIp, hh, nil)
|
||||
buf.Reset()
|
||||
|
||||
// The lighthouse handed us a new set this round.
|
||||
hh.hostinfo.remotes.relays = []netip.Addr{vpnIp, otherRelay}
|
||||
|
||||
rm.StartRelays(f, vpnIp, hh, nil)
|
||||
assert.Equal(t, []netip.Addr{vpnIp, otherRelay}, hh.lastRelays)
|
||||
assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info when the relay list changes")
|
||||
})
|
||||
|
||||
t.Run("disabled relays clears lastRelays and emits no Attempt log", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rm := newRM(&buf)
|
||||
rm.useRelays.Store(false)
|
||||
hh := newHH()
|
||||
hh.lastRelays = []netip.Addr{vpnIp}
|
||||
|
||||
rm.StartRelays(f, vpnIp, hh, nil)
|
||||
assert.Nil(t, hh.lastRelays, "with relays disabled lastRelays should be cleared")
|
||||
assert.NotContains(t, buf.String(), msg, "should not log when we shortcut out")
|
||||
})
|
||||
}
|
||||
@@ -239,6 +239,31 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
|
||||
r.hr = hr
|
||||
}
|
||||
|
||||
// ResetForOwner zeros the reported address slices for the given owner and marks the addrs list dirty.
|
||||
// Any pending hostname resolution will be canceled.
|
||||
func (r *RemoteList) ResetForOwner(ownerVpnAddr netip.Addr) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.hr.Cancel()
|
||||
if c, ok := r.cache[ownerVpnAddr]; ok {
|
||||
if c.v4 != nil {
|
||||
c.v4.reported = c.v4.reported[:0]
|
||||
}
|
||||
if c.v6 != nil {
|
||||
c.v6.reported = c.v6.reported[:0]
|
||||
}
|
||||
}
|
||||
r.shouldRebuild = true
|
||||
}
|
||||
|
||||
// ClearHostnameResults cancels the in-flight DNS resolver goroutine (if any) and drops the resolved IP cache.
|
||||
func (r *RemoteList) ClearHostnameResults() {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.unlockedSetHostnamesResults(nil)
|
||||
r.shouldRebuild = true
|
||||
}
|
||||
|
||||
// Len locks and reports the size of the deduplicated address list
|
||||
// The deduplication work may need to occur here, so you must pass preferredRanges
|
||||
func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
|
||||
|
||||
@@ -6,8 +6,22 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// trackedHostnameResults builds a *hostnamesResults with a known cancel function and a
|
||||
// pre-populated ips map so tests can assert cancellation and verify previously-resolved
|
||||
// IPs survive a cancel without spinning up a real DNS resolver.
|
||||
func trackedHostnameResults(cancelFn func(), addrs ...string) *hostnamesResults {
|
||||
hr := &hostnamesResults{cancelFn: cancelFn}
|
||||
ips := map[netip.AddrPort]struct{}{}
|
||||
for _, a := range addrs {
|
||||
ips[netip.MustParseAddrPort(a)] = struct{}{}
|
||||
}
|
||||
hr.ips.Store(&ips)
|
||||
return hr
|
||||
}
|
||||
|
||||
func TestRemoteList_Rebuild(t *testing.T) {
|
||||
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
|
||||
rl.unlockedSetV4(
|
||||
@@ -112,6 +126,81 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
||||
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
|
||||
}
|
||||
|
||||
func TestRemoteList_ResetForOwner(t *testing.T) {
|
||||
ourselves := netip.MustParseAddr("10.0.0.1")
|
||||
otherOwner := netip.MustParseAddr("10.0.0.2")
|
||||
vpnAddr := netip.MustParseAddr("10.0.0.99")
|
||||
|
||||
rl := NewRemoteList([]netip.Addr{vpnAddr}, nil)
|
||||
rl.unlockedSetV4(ourselves, vpnAddr,
|
||||
[]*V4AddrPort{newIp4AndPortFromString("1.1.1.1:4242")},
|
||||
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||
)
|
||||
rl.unlockedSetV6(ourselves, vpnAddr,
|
||||
[]*V6AddrPort{newIp6AndPortFromString("[1::1]:4242")},
|
||||
func(netip.Addr, *V6AddrPort) bool { return true },
|
||||
)
|
||||
rl.unlockedSetV4(otherOwner, vpnAddr,
|
||||
[]*V4AddrPort{newIp4AndPortFromString("2.2.2.2:4242")},
|
||||
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||
)
|
||||
|
||||
canceled := 0
|
||||
hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")
|
||||
rl.Lock()
|
||||
rl.unlockedSetHostnamesResults(hr)
|
||||
rl.Unlock()
|
||||
|
||||
rl.ResetForOwner(ourselves)
|
||||
|
||||
rl.RLock()
|
||||
defer rl.RUnlock()
|
||||
assert.Empty(t, rl.cache[ourselves].v4.reported, "our v4 reported should be cleared")
|
||||
assert.Empty(t, rl.cache[ourselves].v6.reported, "our v6 reported should be cleared")
|
||||
assert.Len(t, rl.cache[otherOwner].v4.reported, 1, "other owner's contribution must be preserved")
|
||||
assert.Equal(t, "2.2.2.2:4242", protoV4AddrPortToNetAddrPort(rl.cache[otherOwner].v4.reported[0]).String())
|
||||
assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled")
|
||||
assert.Same(t, hr, rl.hr, "hostnamesResults must be preserved so DNS-resolved IPs keep feeding addrs until replaced")
|
||||
assert.NotEmpty(t, rl.hr.GetAddrs(), "previously-resolved IPs should still be readable after cancel")
|
||||
assert.True(t, rl.shouldRebuild, "shouldRebuild must be set so the next Rebuild recomputes addrs")
|
||||
}
|
||||
|
||||
func TestRemoteList_ResetForOwner_NoEntry(t *testing.T) {
|
||||
// An owner with no cache entry must not panic; shouldRebuild is still set and any
|
||||
// existing hostnamesResults is canceled.
|
||||
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil)
|
||||
canceled := 0
|
||||
rl.Lock()
|
||||
rl.unlockedSetHostnamesResults(trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242"))
|
||||
rl.Unlock()
|
||||
|
||||
rl.ResetForOwner(netip.MustParseAddr("10.0.0.1"))
|
||||
|
||||
rl.RLock()
|
||||
defer rl.RUnlock()
|
||||
assert.Equal(t, 1, canceled)
|
||||
assert.True(t, rl.shouldRebuild)
|
||||
}
|
||||
|
||||
func TestRemoteList_ClearHostnameResults(t *testing.T) {
|
||||
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil)
|
||||
|
||||
canceled := 0
|
||||
hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")
|
||||
rl.Lock()
|
||||
rl.unlockedSetHostnamesResults(hr)
|
||||
rl.Unlock()
|
||||
require.NotEmpty(t, hr.GetAddrs(), "hostnamesResults should have its fastrack IPs populated")
|
||||
|
||||
rl.ClearHostnameResults()
|
||||
|
||||
rl.RLock()
|
||||
defer rl.RUnlock()
|
||||
assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled")
|
||||
assert.Nil(t, rl.hr, "hostnamesResults should be dropped")
|
||||
assert.True(t, rl.shouldRebuild)
|
||||
}
|
||||
|
||||
func BenchmarkFullRebuild(b *testing.B) {
|
||||
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
|
||||
rl.unlockedSetV4(
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Scheduler is an allocation-conscious dispatch primitive for delayed work.
// Pending items are handed to a runtime timer, and ready items land on a worker
// channel for centralized dispatch in fire-time order.
//
// Pick a Scheduler when fire timing matters (exact deadlines, no bucketing) or when the scheduling
// rate is uneven enough that idle CPU matters. Each fire is a runtime-spawned goroutine running the callback before
// delivering to the worker, which is fine at sparse rates but adds up at line rate.
//
// Pick a TimerWheel when scheduling is high-rate and uniform: its O(1) insert, internal item cache,
// and bucket-batched dispatch are cheaper at scale.
// The caller drives the tick loop (Advance/Purge) and pays for fires at bucket boundaries rather than exact deadlines.
//
// A Scheduler must be built with NewScheduler; the zero value has no pool constructor and cannot schedule.
type Scheduler[T any] struct {
	// queue carries fired items to the goroutine running Run.
	queue chan T
	// pool recycles schedItems (and the timer and closure each one owns) between Schedule calls.
	pool sync.Pool
}

// schedItem is one pooled timer slot. val and ctx are filled in by Schedule and cleared again by fire
// before the item goes back to the pool; s, timer and fire are set once when the item is created.
type schedItem[T any] struct {
	val   T
	ctx   context.Context
	s     *Scheduler[T]
	timer *time.Timer
	fire  func()
}

// NewScheduler builds a Scheduler whose worker channel is sized to queueSize.
// The buffer absorbs bursts of timers firing close together without
// blocking the runtime's callback goroutines on the worker.
func NewScheduler[T any](queueSize int) *Scheduler[T] {
	s := &Scheduler[T]{
		queue: make(chan T, queueSize),
	}
	s.pool.New = func() any {
		si := &schedItem[T]{s: s}
		// fire is allocated exactly once per pool-resident item.
		// The closure captures only `si`, which stays stable for the item's lifetime.
		si.fire = func() {
			select {
			case si.s.queue <- si.val:
			case <-si.ctx.Done():
			}
			// Drop references before pooling so a parked item does not keep the value or context alive.
			var zero T
			si.val = zero
			si.ctx = nil
			si.s.pool.Put(si)
		}
		// Create the timer here, already stopped, instead of lazily in Schedule. Assigning si.timer after
		// an armed AfterFunc returns would race with fire: with a short delay the callback can finish and
		// Put the item back before the assignment lands, letting a concurrent Schedule observe a
		// half-initialised item. The delay is irrelevant because the timer is stopped before it can expire.
		si.timer = time.AfterFunc(time.Duration(1<<63-1), si.fire)
		si.timer.Stop()
		return si
	}
	return s
}

// Schedule arranges item to be delivered to the worker after delay.
// The runtime's timer heap handles the wait, so the scheduler itself burns no CPU while idle.
// The callback observes ctx: if ctx is cancelled before the item can be handed to the worker queue,
// the item is dropped instead of queued. ctx must be non-nil.
func (s *Scheduler[T]) Schedule(ctx context.Context, item T, delay time.Duration) {
	si := s.pool.Get().(*schedItem[T])
	si.val = item
	si.ctx = ctx
	// Every pooled item already owns a stopped (or already fired) timer, so arming it is always a Reset.
	si.timer.Reset(delay)
}

// Run drains the worker queue, calling fn for each item. Returns when ctx is cancelled.
// Tests that want deterministic timing should drive the queue directly rather than going through Schedule + Run.
func (s *Scheduler[T]) Run(ctx context.Context, fn func(T)) {
	for {
		select {
		case <-ctx.Done():
			return
		case item := <-s.queue:
			fn(item)
		}
	}
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestScheduler_PooledReuse(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := NewScheduler[int](16)
|
||||
delivered := make(chan int, 256)
|
||||
go s.Run(ctx, func(item int) { delivered <- item })
|
||||
|
||||
const N = 100
|
||||
for i := 0; i < N; i++ {
|
||||
s.Schedule(ctx, i, time.Millisecond)
|
||||
}
|
||||
|
||||
deadline := time.After(2 * time.Second)
|
||||
got := 0
|
||||
for got < N {
|
||||
select {
|
||||
case <-delivered:
|
||||
got++
|
||||
case <-deadline:
|
||||
t.Fatalf("only %d/%d items delivered", got, N)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkScheduler_Schedule reports allocations per Schedule call.
|
||||
// In steady state the Scheduler's sync.Pool means we should see zero allocs per op once the pool warms up.
|
||||
func BenchmarkScheduler_Schedule(b *testing.B) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := NewScheduler[int](b.N)
|
||||
go s.Run(ctx, func(int) {})
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Schedule(ctx, i, time.Microsecond)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBareAfterFunc is the baseline for BenchmarkScheduler_Schedule: the cost of calling time.AfterFunc
// directly for every item, without the pooled Scheduler. Each iteration creates a new timer and a new closure.
func BenchmarkBareAfterFunc(b *testing.B) {
	ctx, stop := context.WithCancel(context.Background())
	defer stop()

	sink := make(chan int, b.N)
	// Stand-in for the Scheduler's Run loop: consume and discard until the benchmark ends.
	drain := func() {
		for {
			select {
			case <-sink:
			case <-ctx.Done():
				return
			}
		}
	}
	go drain()

	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		val := n // per-iteration copy for the closure
		time.AfterFunc(time.Microsecond, func() {
			select {
			case sink <- val:
			case <-ctx.Done():
			}
		})
	}
}
|
||||
+38
-21
@@ -27,21 +27,20 @@ type SSHServer struct {
|
||||
commands *radix.Tree
|
||||
listener net.Listener
|
||||
|
||||
// Call the cancel() function to stop all active sessions
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
// ctx parents per-Run contexts. Cancelling it (e.g. via Control.Stop) tears the server down even
|
||||
// across reloads, since each Run derives a fresh child rather than reusing this one directly.
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
||||
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen.
|
||||
// The ssh server's context is parented off the supplied ctx so cancelling it
|
||||
// (e.g. on Control.Stop) tears down active sessions and closes the listener.
|
||||
func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) {
|
||||
s := &SSHServer{
|
||||
trustedKeys: make(map[string]map[string]bool),
|
||||
l: l,
|
||||
commands: radix.New(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
cc := ssh.CertChecker{
|
||||
@@ -151,28 +150,51 @@ func (s *SSHServer) RegisterCommand(c *Command) {
|
||||
s.commands.Insert(c.Name, c)
|
||||
}
|
||||
|
||||
// Run begins listening and accepting connections
|
||||
// Run begins listening and accepting connections. Each invocation derives a fresh per-Run context
|
||||
// from the constructor-supplied ctx so a Stop+Run sequence (used by config reload) starts clean
|
||||
// rather than carrying a permanently-cancelled context across runs.
|
||||
func (s *SSHServer) Run(addr string) error {
|
||||
var err error
|
||||
s.listener, err = net.Listen("tcp", addr)
|
||||
if s.ctx.Err() != nil {
|
||||
return s.ctx.Err()
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// s.listener is the public handle Stop uses to interrupt the active run; listener (the local) is what
|
||||
// this run owns. They start equal but a fast reload may overwrite s.listener with the next run's
|
||||
// listener before this run's watcher fires, so each run must close its own listener via the local
|
||||
// reference.
|
||||
s.listener = listener
|
||||
|
||||
runCtx, cancel := context.WithCancel(s.ctx)
|
||||
defer cancel()
|
||||
|
||||
// Close the listener when this run's context is cancelled. That can come from the parent
|
||||
// (Control.Stop), from Run returning normally (defer cancel above), or transitively when a sibling
|
||||
// run cancels through Stop closing the listener. net.Listener.Close is idempotent so a duplicate
|
||||
// close from Stop is benign.
|
||||
go func() {
|
||||
<-runCtx.Done()
|
||||
if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
s.l.Warn("Failed to close the sshd listener", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
s.l.Info("SSH server is listening", "sshListener", addr)
|
||||
|
||||
// Run loops until there is an error
|
||||
s.run()
|
||||
s.closeSessions()
|
||||
s.run(runCtx, listener)
|
||||
|
||||
s.l.Info("SSH server stopped listening")
|
||||
// We don't return an error because run logs for us
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHServer) run() {
|
||||
func (s *SSHServer) run(ctx context.Context, listener net.Listener) {
|
||||
for {
|
||||
c, err := s.listener.Accept()
|
||||
c, err := listener.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
s.l.Warn("Error in listener, shutting down", "error", err)
|
||||
@@ -184,7 +206,7 @@ func (s *SSHServer) run() {
|
||||
// Ensure that a bad client doesn't hurt us by checking for the parent context
|
||||
// cancellation before calling NewServerConn, and forcing the socket to close when
|
||||
// the context is cancelled.
|
||||
sessionContext, sessionCancel := context.WithCancel(s.ctx)
|
||||
sessionContext, sessionCancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
<-sessionContext.Done()
|
||||
c.Close()
|
||||
@@ -227,14 +249,9 @@ func (s *SSHServer) run() {
|
||||
}
|
||||
|
||||
func (s *SSHServer) Stop() {
|
||||
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
|
||||
if s.listener != nil {
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.l.Warn("Failed to close the sshd listener", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHServer) closeSessions() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
+17
@@ -8,6 +8,23 @@ import (
|
||||
// How many timer objects should be cached
|
||||
const timerCacheMax = 50000
|
||||
|
||||
// TimerWheel is a hashed timing wheel: a fixed slot array indexed by (now + delay) % wheelLen,
|
||||
// with each slot a singly linked list of items due in that bucket.
|
||||
// Adds are O(1), Purges return items in arrival-within-slot order, and an internal cache of TimeoutItems
|
||||
// keeps steady-state inserts allocation-free.
|
||||
//
|
||||
// The TimerWheel does not handle concurrency or lifecycle on its own.
|
||||
// Callers drive Advance/Purge from their own ticker loop, take their own locks (or use LockingTimerWheel),
|
||||
// and decide whether to keep ticking when the wheel is empty.
|
||||
//
|
||||
// Pick a TimerWheel when scheduling is high-rate and uniform: line-rate conntrack inserts,
|
||||
// per-tunnel traffic checks at fixed intervals. O(1) insert plus the item cache means the hot path doesn't allocate.
|
||||
// Items added in the same tick are dispatched together when that slot rotates current,
|
||||
// which amortizes the cost of waking the worker.
|
||||
//
|
||||
// Pick a Scheduler when delay precision matters or scheduling is sparse or uneven.
|
||||
// The wheel rounds requested timeouts up to its tick resolution and clamps anything beyond its wheel duration;
|
||||
// both are silent in this implementation.
|
||||
type TimerWheel[T any] struct {
|
||||
// Current tick
|
||||
current int
|
||||
|
||||
+1
-2
@@ -5,12 +5,11 @@ package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
||||
+1
-2
@@ -8,12 +8,11 @@ package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
//go:build (amd64 || arm64) && !e2e_testing
|
||||
// +build amd64 arm64
|
||||
// +build !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/wfp"
|
||||
)
|
||||
|
||||
// wrapWithWDFBypass wraps a Conn so that the first ReloadConfig consults listen.windows_bypass_wdf
|
||||
// and installs a WFP PERMIT filter for the listener's bound UDP port. The session is released when Close runs.
|
||||
func wrapWithWDFBypass(l *slog.Logger, conn Conn) Conn {
|
||||
return &bypassConn{Conn: conn, l: l}
|
||||
}
|
||||
|
||||
type bypassConn struct {
|
||||
Conn
|
||||
|
||||
l *slog.Logger
|
||||
installOnce sync.Once
|
||||
session *wfp.Session
|
||||
}
|
||||
|
||||
func (b *bypassConn) ReloadConfig(c *config.C) {
|
||||
b.installOnce.Do(func() {
|
||||
if !c.GetBool("listen.windows_bypass_wdf", true) {
|
||||
return
|
||||
}
|
||||
addr, err := b.Conn.LocalAddr()
|
||||
if err != nil {
|
||||
b.l.Warn("Failed to query listener port for WFP bypass", "error", err)
|
||||
return
|
||||
}
|
||||
s, err := wfp.PermitUDPPort(addr.Port())
|
||||
if err != nil {
|
||||
b.l.Warn("Failed to install WFP bypass filters for listener", "error", err)
|
||||
return
|
||||
}
|
||||
b.l.Info("Installed WFP filters bypassing Windows Defender Firewall on UDP listener port",
|
||||
"port", addr.Port())
|
||||
b.session = s
|
||||
})
|
||||
b.Conn.ReloadConfig(c)
|
||||
}
|
||||
|
||||
func (b *bypassConn) Close() error {
|
||||
if b.session != nil {
|
||||
b.session.Close()
|
||||
b.session = nil
|
||||
}
|
||||
return b.Conn.Close()
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import "log/slog"
|
||||
|
||||
// wrapWithWDFBypass is a no-op on windows-386 since we don't currently build for it.
|
||||
func wrapWithWDFBypass(_ *slog.Logger, conn Conn) Conn {
|
||||
return conn
|
||||
}
|
||||
+1
-2
@@ -7,12 +7,11 @@ package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
||||
+9
-4
@@ -19,13 +19,18 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int)
|
||||
return nil, fmt.Errorf("multiple udp listeners not supported on windows")
|
||||
}
|
||||
|
||||
var conn Conn
|
||||
rc, err := NewRIOListener(l, ip, port)
|
||||
if err == nil {
|
||||
return rc, nil
|
||||
conn = rc
|
||||
} else {
|
||||
l.Error("Falling back to standard udp sockets", "error", err)
|
||||
conn, err = NewGenericListener(l, ip, port, multi, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
l.Error("Falling back to standard udp sockets", "error", err)
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
return wrapWithWDFBypass(l, conn), nil
|
||||
}
|
||||
|
||||
func NewListenConfig(multi bool) net.ListenConfig {
|
||||
|
||||
@@ -0,0 +1,377 @@
|
||||
//go:build (amd64 || arm64) && !e2e_testing
|
||||
// +build amd64 arm64
|
||||
// +build !e2e_testing
|
||||
|
||||
// Package wfp installs Windows Filtering Platform (WFP) PERMIT filters in a dynamic, session-scoped sublayer.
|
||||
// Because WFP sits below Windows Defender Firewall, a high-weight permit at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4/V6 lets
|
||||
// the matching inbound traffic through regardless of WDF rules.
|
||||
//
|
||||
// Each Session owns its own engine handle. When the handle closes, every dynamic object added during the session
|
||||
// is auto-deleted by Windows, so there are no orphaned filters.
|
||||
//
|
||||
// Type definitions and constants are derived from the wireguard-windows firewall package (MIT).
|
||||
// Only the subset we exercise is reproduced.
|
||||
package wfp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// FWPM layer GUIDs (fwpmu.h).
|
||||
//
|
||||
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = e1cd9fe7-f4b5-4273-96c0-592e487b8650
|
||||
// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = a3b42c97-9f04-4672-b87e-cee9c483257f
|
||||
var (
|
||||
fwpmLayerAleAuthRecvAcceptV4 = windows.GUID{
|
||||
Data1: 0xe1cd9fe7, Data2: 0xf4b5, Data3: 0x4273,
|
||||
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
|
||||
}
|
||||
fwpmLayerAleAuthRecvAcceptV6 = windows.GUID{
|
||||
Data1: 0xa3b42c97, Data2: 0x9f04, Data3: 0x4672,
|
||||
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
|
||||
}
|
||||
)
|
||||
|
||||
// FWPM_CONDITION_IP_LOCAL_INTERFACE = 4cd62a49-59c3-4969-b7f3-bda5d32890a4
|
||||
var fwpmConditionIPLocalInterface = windows.GUID{
|
||||
Data1: 0x4cd62a49, Data2: 0x59c3, Data3: 0x4969,
|
||||
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
|
||||
}
|
||||
|
||||
// FWPM_CONDITION_IP_PROTOCOL = 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
|
||||
var fwpmConditionIPProtocol = windows.GUID{
|
||||
Data1: 0x3971ef2b, Data2: 0x623e, Data3: 0x4f9a,
|
||||
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
|
||||
}
|
||||
|
||||
// FWPM_CONDITION_IP_LOCAL_PORT = 0c1ba1af-5765-453f-af22-a8f791ac775b
|
||||
var fwpmConditionIPLocalPort = windows.GUID{
|
||||
Data1: 0x0c1ba1af, Data2: 0x5765, Data3: 0x453f,
|
||||
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
|
||||
}
|
||||
|
||||
// IPPROTO_UDP from in.h.
|
||||
const ipprotoUDP uint8 = 17
|
||||
|
||||
// FWP_ACTION_TYPE values (fwptypes.h). PERMIT is terminating.
|
||||
const fwpActionPermit uint32 = 0x00001002 // 0x2 | FWP_ACTION_FLAG_TERMINATING(0x1000)
|
||||
|
||||
// FWP_DATA_TYPE values we use.
|
||||
const (
|
||||
fwpEmpty uint32 = 0
|
||||
fwpUint8 uint32 = 1
|
||||
fwpUint16 uint32 = 2
|
||||
fwpUint64 uint32 = 4
|
||||
)
|
||||
|
||||
// FWP_MATCH_TYPE values.
|
||||
const fwpMatchEqual uint32 = 0
|
||||
|
||||
// FWPM_SESSION flags.
|
||||
const fwpmSessionFlagDynamic uint32 = 0x1
|
||||
|
||||
// FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT prevents lower-priority filters in other sublayers,
|
||||
// notably Windows Defender Firewall's MPSSVC_WF sublayer (which shares our 0xFFFF weight), from overriding this PERMIT.
|
||||
// Without it, a default WDF block at the same sublayer weight can still win arbitration.
|
||||
const fwpmFilterFlagClearActionRight uint32 = 0x8
|
||||
|
||||
// RPC authentication.
|
||||
// RPC_C_AUTHN_WINNT works on workgroup machines with no domain context, whereas
|
||||
// RPC_C_AUTHN_DEFAULT falls back through a chain that can land on something WFP doesn't accept on a fresh box.
|
||||
const rpcCAuthnWinNT uint32 = 10
|
||||
|
||||
// fwpByteBlob (FWP_BYTE_BLOB). 16 bytes on 64-bit.
|
||||
type fwpByteBlob struct {
|
||||
size uint32
|
||||
_ uint32 // padding
|
||||
data *uint8
|
||||
}
|
||||
|
||||
// fwpValue0 / FWP_CONDITION_VALUE0 layout. 16 bytes on 64-bit.
|
||||
// The union is pointer-sized; types <= 32 bits (UINT8/16/32, INT8/16/32, float) live inline in the low bytes
|
||||
// of `value`, while UINT64/INT64/double and aggregate types are stored *by pointer*, even on 64-bit, where the
|
||||
// union member is declared as UINT64*. So when populating an FWP_UINT64 condition, pass
|
||||
// uintptr(unsafe.Pointer(&luidVar)) instead of the LUID inline.
|
||||
type fwpValue0 struct {
|
||||
type_ uint32
|
||||
_ uint32 // padding before union to 8-byte alignment
|
||||
value uintptr
|
||||
}
|
||||
|
||||
// fwpmDisplayData0 / FWPM_DISPLAY_DATA0. 16 bytes on 64-bit.
|
||||
type fwpmDisplayData0 struct {
|
||||
name *uint16
|
||||
description *uint16
|
||||
}
|
||||
|
||||
// fwpmAction0 / FWPM_ACTION0. 20 bytes; no leading padding because actionType
|
||||
// is uint32 and GUID's first field is uint32.
|
||||
type fwpmAction0 struct {
|
||||
actionType uint32
|
||||
filterType windows.GUID
|
||||
}
|
||||
|
||||
// fwpmFilterCondition0. 40 bytes on 64-bit.
|
||||
type fwpmFilterCondition0 struct {
|
||||
fieldKey windows.GUID // 16
|
||||
matchType uint32 // 4
|
||||
_ uint32 // 4 padding
|
||||
conditionValue fwpValue0 // 16
|
||||
}
|
||||
|
||||
// fwpmFilter0. 200 bytes on 64-bit.
|
||||
type fwpmFilter0 struct {
|
||||
filterKey windows.GUID
|
||||
displayData fwpmDisplayData0
|
||||
flags uint32
|
||||
_ uint32 // padding before *GUID
|
||||
providerKey *windows.GUID
|
||||
providerData fwpByteBlob
|
||||
layerKey windows.GUID
|
||||
subLayerKey windows.GUID
|
||||
weight fwpValue0
|
||||
numFilterConditions uint32
|
||||
_ uint32 // padding before pointer
|
||||
filterCondition *fwpmFilterCondition0
|
||||
action fwpmAction0
|
||||
_ [4]byte // layout correction
|
||||
providerContextKey windows.GUID
|
||||
reserved *windows.GUID
|
||||
filterID uint64
|
||||
effectiveWeight fwpValue0
|
||||
}
|
||||
|
||||
// fwpmSublayer0. 72 bytes on 64-bit.
|
||||
type fwpmSublayer0 struct {
|
||||
subLayerKey windows.GUID
|
||||
displayData fwpmDisplayData0
|
||||
flags uint32
|
||||
_ uint32 // padding before *GUID
|
||||
providerKey *windows.GUID
|
||||
providerData fwpByteBlob
|
||||
weight uint16
|
||||
_ [6]byte // padding to 72 bytes
|
||||
}
|
||||
|
||||
// fwpmSession0. 72 bytes on 64-bit.
|
||||
type fwpmSession0 struct {
|
||||
sessionKey windows.GUID
|
||||
displayData fwpmDisplayData0
|
||||
flags uint32
|
||||
txnWaitTimeoutInMSec uint32
|
||||
processId uint32
|
||||
_ uint32 // padding before *SID
|
||||
sid *windows.SID
|
||||
username *uint16
|
||||
kernelMode uint8
|
||||
_ [7]byte // tail padding
|
||||
}
|
||||
|
||||
// fwpuclnt.dll bindings. Only the calls we use.
|
||||
var (
|
||||
modFwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
|
||||
procFwpmEngineOpen0 = modFwpuclnt.NewProc("FwpmEngineOpen0")
|
||||
procFwpmEngineClose0 = modFwpuclnt.NewProc("FwpmEngineClose0")
|
||||
procFwpmSubLayerAdd0 = modFwpuclnt.NewProc("FwpmSubLayerAdd0")
|
||||
procFwpmFilterAdd0 = modFwpuclnt.NewProc("FwpmFilterAdd0")
|
||||
)
|
||||
|
||||
// Session holds the WFP engine handle for a single bypass operation. The handle owns a dynamic session:
|
||||
// when it is closed, every WFP object added during the session (sublayer + filters) is automatically deleted by
|
||||
// Windows. That gives us correct cleanup even if the host process is killed hard between Permit* and Close.
|
||||
type Session struct {
|
||||
engine uintptr
|
||||
}
|
||||
|
||||
// Close releases the engine handle. Windows deletes every dynamic object (sublayer + filters) the session installed.
|
||||
// Safe to call on a nil receiver.
|
||||
func (s *Session) Close() {
|
||||
if s == nil || s.engine == 0 {
|
||||
return
|
||||
}
|
||||
procFwpmEngineClose0.Call(s.engine)
|
||||
s.engine = 0
|
||||
}
|
||||
|
||||
// PermitInterface installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to the given network
|
||||
// interface LUID. Inbound traffic on that interface bypasses Windows Defender Firewall.
|
||||
func PermitInterface(luid uint64) (*Session, error) {
|
||||
s, sublayerKey, err := newSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, luid); err != nil {
|
||||
s.Close()
|
||||
return nil, fmt.Errorf("add v4 filter: %w", err)
|
||||
}
|
||||
if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, luid); err != nil {
|
||||
s.Close()
|
||||
return nil, fmt.Errorf("add v6 filter: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// PermitUDPPort installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to UDP traffic with the
|
||||
// given local port. Inbound UDP to that port on any interface bypasses Windows Defender Firewall.
|
||||
func PermitUDPPort(port uint16) (*Session, error) {
|
||||
s, sublayerKey, err := newSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, port); err != nil {
|
||||
s.Close()
|
||||
return nil, fmt.Errorf("add v4 filter: %w", err)
|
||||
}
|
||||
if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, port); err != nil {
|
||||
s.Close()
|
||||
return nil, fmt.Errorf("add v6 filter: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func newSession() (*Session, windows.GUID, error) {
|
||||
engine, err := openDynamicEngine()
|
||||
if err != nil {
|
||||
return nil, windows.GUID{}, err
|
||||
}
|
||||
sublayerKey, err := registerSublayer(engine)
|
||||
if err != nil {
|
||||
procFwpmEngineClose0.Call(engine)
|
||||
return nil, windows.GUID{}, err
|
||||
}
|
||||
return &Session{engine: engine}, sublayerKey, nil
|
||||
}
|
||||
|
||||
func openDynamicEngine() (uintptr, error) {
|
||||
session := fwpmSession0{flags: fwpmSessionFlagDynamic}
|
||||
var engine uintptr
|
||||
r1, _, _ := procFwpmEngineOpen0.Call(
|
||||
0, // serverName == NULL (local)
|
||||
uintptr(rpcCAuthnWinNT),
|
||||
0, // authIdentity == NULL
|
||||
uintptr(unsafe.Pointer(&session)),
|
||||
uintptr(unsafe.Pointer(&engine)),
|
||||
)
|
||||
if r1 != 0 {
|
||||
return 0, fmt.Errorf("FwpmEngineOpen0: 0x%x", r1)
|
||||
}
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// registerSublayer adds a session-scoped sublayer with a freshly generated GUID, weight 0xFFFF so its filters arbitrate
|
||||
// above WDF's default sublayer. The sublayer is dynamic (no PERSISTENT flag) and goes away when the engine handle closes.
|
||||
func registerSublayer(engine uintptr) (windows.GUID, error) {
|
||||
key, err := windows.GenerateGUID()
|
||||
if err != nil {
|
||||
return windows.GUID{}, fmt.Errorf("GenerateGUID for sublayer: %w", err)
|
||||
}
|
||||
|
||||
name, _ := windows.UTF16PtrFromString("Nebula WDF bypass sublayer")
|
||||
desc, _ := windows.UTF16PtrFromString("Permit filters bypassing Windows Defender Firewall")
|
||||
sl := fwpmSublayer0{
|
||||
subLayerKey: key,
|
||||
displayData: fwpmDisplayData0{name: name, description: desc},
|
||||
weight: 0xFFFF,
|
||||
}
|
||||
r1, _, _ := procFwpmSubLayerAdd0.Call(
|
||||
engine,
|
||||
uintptr(unsafe.Pointer(&sl)),
|
||||
0, // sd == NULL
|
||||
)
|
||||
if r1 != 0 {
|
||||
return windows.GUID{}, fmt.Errorf("FwpmSubLayerAdd0: 0x%x", r1)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func addInterfaceFilter(engine uintptr, sublayerKey, layer windows.GUID, luid uint64) error {
|
||||
name, _ := windows.UTF16PtrFromString("Nebula allow interface inbound")
|
||||
desc, _ := windows.UTF16PtrFromString("Permits inbound traffic on a nebula interface")
|
||||
|
||||
// luid must remain addressable through the syscall -- FWP_UINT64 is stored
|
||||
// by pointer in the FWP_VALUE0 union.
|
||||
cond := fwpmFilterCondition0{
|
||||
fieldKey: fwpmConditionIPLocalInterface,
|
||||
matchType: fwpMatchEqual,
|
||||
conditionValue: fwpValue0{
|
||||
type_: fwpUint64,
|
||||
value: uintptr(unsafe.Pointer(&luid)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := fwpmFilter0{
|
||||
// filterKey left zero: WFP assigns one when the filter is added.
|
||||
displayData: fwpmDisplayData0{name: name, description: desc},
|
||||
flags: fwpmFilterFlagClearActionRight,
|
||||
layerKey: layer,
|
||||
subLayerKey: sublayerKey,
|
||||
weight: fwpValue0{type_: fwpUint8, value: uintptr(15)},
|
||||
numFilterConditions: 1,
|
||||
filterCondition: &cond,
|
||||
action: fwpmAction0{actionType: fwpActionPermit},
|
||||
}
|
||||
|
||||
r1, _, _ := procFwpmFilterAdd0.Call(
|
||||
engine,
|
||||
uintptr(unsafe.Pointer(&filter)),
|
||||
0, // sd == NULL
|
||||
0, // id == NULL
|
||||
)
|
||||
if r1 != 0 {
|
||||
return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addUDPPortFilter installs a PERMIT filter that matches (IP_PROTOCOL == UDP) AND (IP_LOCAL_PORT == port).
|
||||
// FWP_UINT8 and FWP_UINT16 are <= 32 bits so they live inline in the FWP_VALUE0 union.
|
||||
func addUDPPortFilter(engine uintptr, sublayerKey, layer windows.GUID, port uint16) error {
|
||||
name, _ := windows.UTF16PtrFromString("Nebula allow UDP port inbound")
|
||||
desc, _ := windows.UTF16PtrFromString("Permits inbound UDP to a nebula listener port")
|
||||
|
||||
conds := [2]fwpmFilterCondition0{
|
||||
{
|
||||
fieldKey: fwpmConditionIPProtocol,
|
||||
matchType: fwpMatchEqual,
|
||||
conditionValue: fwpValue0{
|
||||
type_: fwpUint8,
|
||||
value: uintptr(ipprotoUDP),
|
||||
},
|
||||
},
|
||||
{
|
||||
fieldKey: fwpmConditionIPLocalPort,
|
||||
matchType: fwpMatchEqual,
|
||||
conditionValue: fwpValue0{
|
||||
type_: fwpUint16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filter := fwpmFilter0{
|
||||
displayData: fwpmDisplayData0{name: name, description: desc},
|
||||
flags: fwpmFilterFlagClearActionRight,
|
||||
layerKey: layer,
|
||||
subLayerKey: sublayerKey,
|
||||
weight: fwpValue0{type_: fwpUint8, value: uintptr(15)},
|
||||
numFilterConditions: 2,
|
||||
filterCondition: &conds[0],
|
||||
action: fwpmAction0{actionType: fwpActionPermit},
|
||||
}
|
||||
|
||||
r1, _, _ := procFwpmFilterAdd0.Call(
|
||||
engine,
|
||||
uintptr(unsafe.Pointer(&filter)),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if r1 != 0 {
|
||||
return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user