diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9ce1d5e3..a5e8d397 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -209,10 +209,11 @@ jobs: id: create_release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REF_NAME: ${{ github.ref_name }} run: | cd artifacts gh release create \ --verify-tag \ - --title "Release ${{ github.ref_name }}" \ - "${{ github.ref_name }}" \ + --title "Release ${GITHUB_REF_NAME}" \ + "${GITHUB_REF_NAME}" \ SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 24f899ab..3734db75 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -18,6 +18,8 @@ jobs: if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') name: Run extra smoke tests runs-on: ubuntu-latest + env: + VAGRANT_DEFAULT_PROVIDER: libvirt steps: - uses: actions/checkout@v6 @@ -30,8 +32,13 @@ jobs: - name: add hashicorp source run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list - - name: install vagrant - run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox + - name: install vagrant and libvirt + run: | + sudo apt-get update && sudo apt-get install -y vagrant libvirt-daemon-system libvirt-dev + sudo chmod 666 /dev/kvm + sudo usermod -aG libvirt $(whoami) + sudo chmod 666 /var/run/libvirt/libvirt-sock + vagrant plugin install vagrant-libvirt - name: freebsd-amd64 run: make smoke-vagrant/freebsd-amd64 @@ -42,10 +49,19 @@ jobs: - name: netbsd-amd64 run: make smoke-vagrant/netbsd-amd64 - - name: linux-386 - run: make smoke-vagrant/linux-386 - - name: 
linux-amd64-ipv6disable run: make smoke-vagrant/linux-amd64-ipv6disable + # linux-386 runs last because it requires disabling KVM to use VirtualBox, + # which prevents libvirt (used by the other tests) from working after this point. + - name: install virtualbox for i386 test + run: | + sudo apt-get install -y virtualbox + sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true + + - name: linux-386 + env: + VAGRANT_DEFAULT_PROVIDER: virtualbox + run: make smoke-vagrant/linux-386 + timeout-minutes: 30 diff --git a/.github/workflows/smoke/build-relay.sh b/.github/workflows/smoke/build-relay.sh index 70b07f4e..249e6c84 100755 --- a/.github/workflows/smoke/build-relay.sh +++ b/.github/workflows/smoke/build-relay.sh @@ -16,8 +16,10 @@ relay: am_relay: true EOF - export LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" - export REMOTE_ALLOW_LIST='{"172.17.0.4/32": false, "172.17.0.5/32": false}' + # TEST-NET-3 placeholder IPs; smoke-relay.sh seds them to real container IPs. + # Mapping: .2 lighthouse1, .3 host2, .4 host3, .5 host4. + export LIGHTHOUSES="192.168.100.1 203.0.113.2:4242" + export REMOTE_ALLOW_LIST='{"203.0.113.4/32": false, "203.0.113.5/32": false}' HOST="host2" ../genconfig.sh >host2.yml <host3.yml diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index dcd132b0..b23516ee 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -5,9 +5,15 @@ set -e -x rm -rf ./build mkdir ./build -# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1 -# - We could make this better by launching the lighthouse first and then fetching what IP it is. -NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)" +# Smoke containers run on a dedicated docker network whose subnet is allocated +# at smoke time, not known at build time. 
Configs are written with TEST-NET-3 +# placeholder IPs (RFC 5737) and smoke.sh / smoke-vagrant.sh / smoke-relay.sh +# sed the real container IPs in before starting nebula. +# +# Placeholder mapping (last octet == fixed container slot): +# 203.0.113.2 -> lighthouse1, 203.0.113.3 -> host2, +# 203.0.113.4 -> host3, 203.0.113.5 -> host4. +LIGHTHOUSE_IP="203.0.113.2" ( cd build @@ -25,16 +31,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ ../genconfig.sh >lighthouse1.yml HOST="host2" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ ../genconfig.sh >host2.yml HOST="host3" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host3.yml HOST="host4" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml diff --git a/.github/workflows/smoke/smoke-relay.sh b/.github/workflows/smoke/smoke-relay.sh index 9c113e18..aa1cd915 100755 --- a/.github/workflows/smoke/smoke-relay.sh +++ b/.github/workflows/smoke/smoke-relay.sh @@ -6,6 +6,8 @@ set -o pipefail mkdir -p logs +NETWORK="nebula-smoke-relay" + cleanup() { echo echo " *** cleanup" @@ -16,22 +18,53 @@ cleanup() { then docker kill lighthouse1 host2 host3 host4 fi + docker network rm "$NETWORK" >/dev/null 2>&1 } trap cleanup EXIT -docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test -docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test -docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test -docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). 
Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2, .4 host3, .5 host4 — matches the placeholders in build-relay.sh. +SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" +HOST3_IP="$PREFIX.4" +HOST4_IP="$PREFIX.5" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. 
+for f in build/host2.yml build/host3.yml build/host4.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + +docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" nebula:smoke-relay -config host2.yml -test +docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" nebula:smoke-relay -config host3.yml -test +docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" nebula:smoke-relay -config host4.yml -test + +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device 
/dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 set +x @@ -76,7 +109,13 @@ docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 5 + +# Wait up to 30s for all backgrounded jobs to exit rather than relying on a +# fixed sleep. +for _ in $(seq 1 30); do + [ -z "$(jobs -r)" ] && break + sleep 1 +done if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 1c1e3c50..e3863cb5 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -8,6 +8,8 @@ export VAGRANT_CWD="$PWD/vagrant-$1" mkdir -p logs +NETWORK="nebula-smoke" + cleanup() { echo echo " *** cleanup" @@ -19,21 +21,51 @@ cleanup() { docker kill lighthouse1 host2 fi vagrant destroy -f + docker network rm "$NETWORK" >/dev/null 2>&1 } trap cleanup EXIT +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi + +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2 — matches the placeholders in build.sh. 
+SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. +# This must happen before `vagrant up` rsyncs build/ into the VM for host3. +for f in build/host2.yml build/host3.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test -docker run --name host2 --rm "$CONTAINER" -config host2.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test vagrant up vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 15 @@ -96,7 +128,14 @@ vagrant ssh -c "ping -c1 192.168.100.2" -- -T vagrant ssh -c "sudo xargs kill /dev/null 2>&1 } trap 
cleanup EXIT +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi + +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2, .4 host3, .5 host4 — matches the placeholders in build.sh. +SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" +HOST3_IP="$PREFIX.4" +HOST4_IP="$PREFIX.5" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. +# build/lighthouse1.yml has no IPs to rewrite so it's skipped. 
+for f in build/host2.yml build/host3.yml build/host4.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test -docker run --name host2 --rm "$CONTAINER" -config host2.yml -test -docker run --name host3 --rm "$CONTAINER" -config host3.yml -test -docker run --name host4 --rm "$CONTAINER" -config host4.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test +docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" "$CONTAINER" -config host3.yml -test +docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" "$CONTAINER" -config host4.yml -test -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" 
-config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 # grab tcpdump pcaps for debugging -docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & +docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & -docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & +docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & -docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & +docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & -docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & +docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & docker exec host2 ncat -nklv 0.0.0.0 2000 & docker exec host3 ncat -nklv 0.0.0.0 2000 & +docker exec host4 ncat -e '/usr/bin/echo helloagainfromhost4' -nkluv 0.0.0.0 4000 & docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 
& docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & @@ -119,17 +154,24 @@ echo echo " *** Testing conntrack" echo set -x -# host2 can ping host3 now that host3 pinged it first -docker exec host2 ping -c1 192.168.100.3 -# host4 can ping host2 once conntrack established -docker exec host2 ping -c1 192.168.100.4 -docker exec host4 ping -c1 192.168.100.2 + +# host4's outbound firewall only allows ICMP to the lighthouse, so host4 +# cannot initiate UDP to host2. Once host2 initiates a flow to host4:4000, +# conntrack must let host4's listener reply on that flow. If it doesn't, +# the echo back from host4 never reaches host2. +docker exec host2 sh -c "(/usr/bin/echo host2; sleep 2) | ncat -nuv 192.168.100.4 4000" | grep -q helloagainfromhost4 docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 5 + +# Wait up to 30s for all backgrounded jobs to exit rather than relying on a +# fixed sleep. 
+for _ in $(seq 1 30); do + [ -z "$(jobs -r)" ] && break + sleep 1 +done if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile index 89f94772..eeb9679e 100644 --- a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile +++ b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "ubuntu/jammy64" + config.vm.box = "bento/ubuntu-24.04" config.vm.synced_folder "../build", "/nebula" diff --git a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile index e4f41049..6dd26373 100644 --- a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile +++ b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "generic/openbsd7" + config.vm.box = "DefinedNet/openbsd78" config.vm.synced_folder "../build", "/nebula", type: "rsync" end diff --git a/.golangci.yaml b/.golangci.yaml index bd82a952..be0513d4 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -2,7 +2,21 @@ version: "2" linters: default: none enable: + - sloglint - testifylint + settings: + sloglint: + # Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and + # the package-level slog equivalents. Use l.Log(ctx, level, ...) for + # custom levels instead of LogAttrs when you can. + # + # LogAttrs is also flagged by this rule because it takes ...slog.Attr; + # the few legitimate sites (where attrs is built up as a []slog.Attr) + # carry a //nolint:sloglint with rationale. + kv-only: true + # no-mixed-args is on by default: forbids mixing kv and attrs in one call. 
+ # discard-handler is on by default (since Go 1.24): suggests + # slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil). exclusions: generated: lax presets: diff --git a/CHANGELOG.md b/CHANGELOG.md index 104b52e3..2ef7551f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.10.3] - 2026-02-06 + +### Security + +- Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations. + Both fingerprint representations will be tested against the blocklist. + Any newly issued P256 based certificates will have their signature clamped to the low-s form. + Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962) + +### Changed + +- Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588) + +## [1.10.2] - 2026-01-21 + +### Fixed + +- Fix panic when using `use_system_route_table` that was introduced in v1.10.1. (#1580) + +### Changed + +- Fix some typos in comments. (#1582) +- Dependency updates. (#1581) + ## [1.10.1] - 2026-01-16 See the [v1.10.1](https://github.com/slackhq/nebula/milestone/26?closed=1) milestone for a complete list of changes. @@ -764,7 +788,9 @@ created.) - Initial public release. 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.1...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD +[1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3 +[1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2 [1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1 [1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0 [1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7 diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..00cd7bd1 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +#ECCN:Open Source diff --git a/README.md b/README.md index fab9cff1..7cbcb412 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for docker pull nebulaoss/nebula ``` -#### Mobile +#### Mobile ([source code](https://github.com/DefinedNet/mobile_nebula)) - [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200) - [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1) @@ -76,6 +76,8 @@ Nebula was created to provide a mechanism for groups of hosts to communicate sec ## Getting started (quickly) +**Don't want to manage your own PKI and lighthouses?** [Managed Nebula](https://www.defined.net/) from Defined Networking handles all of this for you. + To set up a Nebula network, you'll need: #### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use. 
diff --git a/bits.go b/bits.go index af11cc48..15bafd87 100644 --- a/bits.go +++ b/bits.go @@ -1,23 +1,43 @@ package nebula import ( + "context" + "fmt" + "log/slog" + "math" + mathbits "math/bits" + "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" ) +const bitsPerWord = 64 + +// Bits is a sliding-window anti-replay tracker. The window is stored as a +// circular bitmap packed into uint64 words (8x denser than a []bool), so a +// length-N window costs N/8 bytes. length must be a power of two. type Bits struct { length uint64 + lengthMask uint64 current uint64 - bits []bool + bits []uint64 lostCounter metrics.Counter dupeCounter metrics.Counter outOfWindowCounter metrics.Counter } -func NewBits(bits uint64) *Bits { +func NewBits(length uint64) *Bits { + if length == 0 || length&(length-1) != 0 { + panic(fmt.Sprintf("Bits length must be a power of two, got %d", length)) + } + + nWords := length / bitsPerWord + if nWords == 0 { + nWords = 1 + } b := &Bits{ - length: bits, - bits: make([]bool, bits, bits), + length: length, + lengthMask: length - 1, + bits: make([]uint64, nWords), current: 0, lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), @@ -25,88 +45,219 @@ func NewBits(bits uint64) *Bits { } // There is no counter value 0, mark it to avoid counting a lost packet later. 
- b.bits[0] = true - b.current = 0 + b.bits[0] = 1 return b } -func (b *Bits) Check(l *logrus.Logger, i uint64) bool { +func (b *Bits) get(i uint64) bool { + pos := i & b.lengthMask + //bit-shifting by 6 because i is a bit index, not a u64 index, and we need to find the u64 without bit in it + return b.bits[pos>>6]&(uint64(1)<<(pos&63)) != 0 +} + +func (b *Bits) set(i uint64) { + pos := i & b.lengthMask + b.bits[pos>>6] |= uint64(1) << (pos & 63) +} + +// clearRange clears `count` bits starting at circular position `startPos` +// (already masked to [0, length)) and returns how many of them were set +// before the clear. count must be in [1, length]. +func (b *Bits) clearRange(startPos, count uint64) uint64 { + wasSet := uint64(0) + if count >= b.length { + for _, w := range b.bits { + wasSet += uint64(mathbits.OnesCount64(w)) + } + clear(b.bits) + return wasSet + } + + pos := startPos + remaining := count + + // handle the potential partial word before pos becomes u64 aligned + word := pos >> 6 + bit := pos & 63 + take := uint64(64) - bit + if take > remaining { + take = remaining + } + if take > b.length-pos { + take = b.length - pos + } + var mask uint64 + if take == 64 { + mask = math.MaxUint64 + } else { + mask = ((uint64(1) << take) - 1) << bit + } + wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask)) + b.bits[word] &^= mask + remaining -= take + pos = (pos + take) & b.lengthMask + + // Clear whole words, keeping track of the number of set bits + for remaining >= 64 { + word = pos >> 6 + wasSet += uint64(mathbits.OnesCount64(b.bits[word])) + b.bits[word] = 0 + remaining -= 64 + pos = (pos + 64) & b.lengthMask + } + + // Clear the remaining partial word + if remaining > 0 { + word = pos >> 6 + mask = (uint64(1) << remaining) - 1 + wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask)) + b.bits[word] &^= mask + } + + return wasSet +} + +func (b *Bits) strictlyWithinWindow(i uint64) bool { + // Handle the case where the window hasn't slid yet. 
This avoids u64 underflow. + inWarmup := b.current < b.length + if i < b.length && inWarmup { + return true + } + + // Next, if the packet is in-window, see if we've seen it before + if i > b.current-b.length { + return true + } + return false //not within window! +} + +// Check returns true if i is within (or way out in front of) the window, and not a replay +func (b *Bits) Check(l *slog.Logger, i uint64) bool { // If i is the next number, return true. if i > b.current { return true } - // If i is within the window, check if it's been set already. - if i > b.current-b.length || i < b.length && b.current < b.length { - return !b.bits[i%b.length] + if b.strictlyWithinWindow(i) { + return !b.get(i) } // Not within the window - if l.Level >= logrus.DebugLevel { - l.Debugf("rejected a packet (top) %d %d\n", b.current, i) + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("rejected a packet (top)", "current", b.current, "incoming", i) } return false } -func (b *Bits) Update(l *logrus.Logger, i uint64) bool { - // If i is the next number, return true and update current. +// Update has three branches: +// - i == b.current+1: fast path; advance the cursor by one and lose-count +// the slot we just stomped (only past warmup; see the i > b.length guard +// below). +// - i > b.current+1: jump path; clear all slots between current and i +// (or up to a full window's worth, whichever is smaller) via clearRange, +// then mark i. Two arms here: a warmup arm that handles the very first +// window before the cursor has slid, and a steady-state arm that treats +// every cleared empty slot as a lost packet. +// - i <= b.current: in-window check for duplicates; out-of-window otherwise. +// +// NewBits seeds bits[0]=1 so counter 0 looks "received" — Update never +// clears that marker during warmup (clearRange skips position 0 when +// startPos=1), and once b.current >= b.length the marker is no longer +// consulted. 
The marker prevents a fictitious "lost" hit on the first real +// counter. +func (b *Bits) Update(l *slog.Logger, i uint64) bool { + // Fast path: i is the next expected counter. Split out so the function + // stays small and avoids paying for the slow paths' slog argument-build + // stack frame on every call. The bit read/test/write is inlined to + // touch the backing word once. if i == b.current+1 { - // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter - // The very first window can only be tracked as lost once we are on the 2nd window or greater - if b.bits[i%b.length] == false && i > b.length { + pos := i & b.lengthMask + word := pos >> 6 + mask := uint64(1) << (pos & 63) + w := b.bits[word] + if i > b.length && w&mask == 0 { b.lostCounter.Inc(1) } - b.bits[i%b.length] = true + b.bits[word] = w | mask b.current = i return true } + return b.updateSlow(l, i) +} +// updateSlow handles jumps, in-window backfill, dupes, and out-of-window. +func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool { // If i is a jump, adjust the window, record lost, update current, and return true if i > b.current { - lost := int64(0) - // Zero out the bits between the current and the new counter value, limited by the window size, - // since the window is shifting - for n := b.current + 1; n <= min(i, b.current+b.length); n++ { - if b.bits[n%b.length] == false && n > b.length { - lost++ + end := i + if end > b.current+b.length { + end = b.current + b.length + } + count := end - b.current + startPos := (b.current + 1) & b.lengthMask + + var lost int64 + if b.current >= b.length { + // Steady state: every cleared slot is past warmup, so any unset + // bit we evict is a lost packet from the previous cycle. + wasSet := b.clearRange(startPos, count) + lost = int64(count) - int64(wasSet) + } else { + // Warmup (the very first window). 
Some cleared slots represent + // packets <= length where eviction is not "lost" in the usual + // sense. This branch is taken at most once per connection so we + // don't bother optimizing it. + for n := b.current + 1; n <= end; n++ { + if !b.get(n) && n > b.length { + lost++ + } } - b.bits[n%b.length] = false + b.clearRange(startPos, count) } - // Only record any skipped packets as a result of the window moving further than the window length - // Any loss within the new window will be accounted for in future calls - lost += max(0, int64(i-b.current-b.length)) + // Anything past the new window can never be backfilled, so it's lost. + if i > b.current+b.length { + lost += int64(i - b.current - b.length) + } b.lostCounter.Inc(lost) - b.bits[i%b.length] = true + b.set(i) b.current = i return true } - // If i is within the current window but below the current counter, - // Check to see if it's a duplicate - if i > b.current-b.length || i < b.length && b.current < b.length { - if b.current == i || b.bits[i%b.length] == true { - if l.Level >= logrus.DebugLevel { - l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). - Debug("Receive window") + // If i is within the current window but below the current counter, check to see if it's a duplicate + if b.strictlyWithinWindow(i) { + pos := i & b.lengthMask + word := pos >> 6 + mask := uint64(1) << (pos & 63) + w := b.bits[word] + if b.current == i || w&mask != 0 { + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "duplicate", + ) } b.dupeCounter.Inc(1) return false } - b.bits[i%b.length] = true + b.bits[word] = w | mask return true } // In all other cases, fail and don't change current. b.outOfWindowCounter.Inc(1) - if l.Level >= logrus.DebugLevel { - l.WithField("accepted", false). - WithField("currentCounter", b.current). 
- WithField("incomingCounter", i). - WithField("reason", "nonsense"). - Debug("Receive window") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "nonsense", + ) } return false } diff --git a/bits_test.go b/bits_test.go index 3504cefa..da44c92a 100644 --- a/bits_test.go +++ b/bits_test.go @@ -7,61 +7,79 @@ import ( "github.com/stretchr/testify/assert" ) +// snapshot returns the bitmap as a []bool of length b.length, for readable +// test assertions against the now-packed []uint64 storage. +func (b *Bits) snapshot() []bool { + out := make([]bool, b.length) + for i := uint64(0); i < b.length; i++ { + out[i] = b.get(i) + } + return out +} + +func TestBitsRequiresPowerOfTwo(t *testing.T) { + assert.Panics(t, func() { NewBits(10) }) + assert.Panics(t, func() { NewBits(0) }) + assert.NotPanics(t, func() { NewBits(1) }) + assert.NotPanics(t, func() { NewBits(16) }) + assert.NotPanics(t, func() { NewBits(1024) }) + assert.NotPanics(t, func() { NewBits(16384) }) +} + func TestBits(t *testing.T) { l := test.NewLogger() - b := NewBits(10) - - // make sure it is the right size - assert.Len(t, b.bits, 10) + b := NewBits(16) + assert.EqualValues(t, 16, b.length) // This is initialized to zero - receive one. This should work. 
assert.True(t, b.Check(l, 1)) assert.True(t, b.Update(l, 1)) assert.EqualValues(t, 1, b.current) - g := []bool{true, true, false, false, false, false, false, false, false, false} - assert.Equal(t, g, b.bits) + g := []bool{true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) // Receive two assert.True(t, b.Check(l, 2)) assert.True(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) - g = []bool{true, true, true, false, false, false, false, false, false, false} - assert.Equal(t, g, b.bits) + g = []bool{true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) // Receive two again - it will fail assert.False(t, b.Check(l, 2)) assert.False(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) - // Jump ahead to 15, which should clear everything and set the 6th element - assert.True(t, b.Check(l, 15)) - assert.True(t, b.Update(l, 15)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, false, true, false, false, false, false} - assert.Equal(t, g, b.bits) + // Jump ahead to 25, which clears the window and sets slot 25%16 = 9. + assert.True(t, b.Check(l, 25)) + assert.True(t, b.Update(l, 25)) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // Mark 14, which is allowed because it is in the window - assert.True(t, b.Check(l, 14)) - assert.True(t, b.Update(l, 14)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, true, true, false, false, false, false} - assert.Equal(t, g, b.bits) + // Mark 24, which is in window (current 25, length 16, window covers [10,25]). 
+ assert.True(t, b.Check(l, 24)) + assert.True(t, b.Update(l, 24)) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // Mark 5, which is not allowed because it is not in the window + // Mark 5, not allowed because 5 <= current-length (25-16=9). assert.False(t, b.Check(l, 5)) assert.False(t, b.Update(l, 5)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, true, true, false, false, false, false} - assert.Equal(t, g, b.bits) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // make sure we handle wrapping around once to the current position - b = NewBits(10) + // Make sure we handle wrapping around once to the same slot. With + // length=16, packets 1 and 17 share slot 1. + b = NewBits(16) assert.True(t, b.Update(l, 1)) - assert.True(t, b.Update(l, 11)) - assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits) + assert.True(t, b.Update(l, 17)) + assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}, b.snapshot()) // Walk through a few windows in order - b = NewBits(10) + b = NewBits(16) for i := uint64(1); i <= 100; i++ { assert.True(t, b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i) @@ -72,24 +90,31 @@ func TestBits(t *testing.T) { func TestBitsLargeJumps(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + + // length=16. Update(55) from current=0: + // warmup, per-bit loop sees no n>16 with unset bits (slot 0 was set by + // NewBits and gets re-evaluated when n=16; n=16 is not strictly > 16), + // so the loop contributes 0. 
The jump exceeds the window so we record + // 55 - 0 - 16 = 39 packets fell out the back. + b := NewBits(16) b.lostCounter.Clear() + assert.True(t, b.Update(l, 55)) + assert.Equal(t, int64(39), b.lostCounter.Count()) - b = NewBits(10) - b.lostCounter.Clear() - assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54 - assert.Equal(t, int64(45), b.lostCounter.Count()) + // Update(100): clears 16 slots starting at slot 56%16=8. Only slot 7 (for + // packet 55) was set, so 16 - 1 = 15 evicted slots had unset bits. + // Plus 100 - 55 - 16 = 29 packets fell past the window. Total 44. + assert.True(t, b.Update(l, 100)) + assert.Equal(t, int64(39+44), b.lostCounter.Count()) - assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99 - assert.Equal(t, int64(89), b.lostCounter.Count()) - - assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199 - assert.Equal(t, int64(188), b.lostCounter.Count()) + // Update(200): same shape: 16 - 1 = 15 evicted unset, plus 200 - 100 - 16 = 84 past window. Total 99. + assert.True(t, b.Update(l, 200)) + assert.Equal(t, int64(39+44+99), b.lostCounter.Count()) } func TestBitsDupeCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() @@ -114,120 +139,117 @@ func TestBitsDupeCounter(t *testing.T) { func TestBitsOutOfWindowCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() + // Jump to 20 (warmup branch + 4 past-window packets). 
assert.True(t, b.Update(l, 20)) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - assert.True(t, b.Update(l, 21)) - assert.True(t, b.Update(l, 22)) - assert.True(t, b.Update(l, 23)) - assert.True(t, b.Update(l, 24)) - assert.True(t, b.Update(l, 25)) - assert.True(t, b.Update(l, 26)) - assert.True(t, b.Update(l, 27)) - assert.True(t, b.Update(l, 28)) - assert.True(t, b.Update(l, 29)) + // 9 single-step advances, each evicts a slot whose bit was cleared during + // the jump above and whose value was never seen, so each contributes 1 + // to lostCounter. + for n := uint64(21); n <= 29; n++ { + assert.True(t, b.Update(l, n)) + } assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) + // 0 is below current-length (29-16=13) so it falls outside the window. assert.False(t, b.Update(l, 0)) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) - assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost + // 4 from the Update(20) jump + 9 from 21..29. + assert.Equal(t, int64(13), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) } func TestBitsLostCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(l, 20)) - assert.True(t, b.Update(l, 21)) - assert.True(t, b.Update(l, 22)) - assert.True(t, b.Update(l, 23)) - assert.True(t, b.Update(l, 24)) - assert.True(t, b.Update(l, 25)) - assert.True(t, b.Update(l, 26)) - assert.True(t, b.Update(l, 27)) - assert.True(t, b.Update(l, 28)) - assert.True(t, b.Update(l, 29)) - assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost + // Walk 20..29 like the original, just with a bigger window. Same + // reasoning as TestBitsOutOfWindowCounter: 4 past-window from Update(20), + // then 9 more from the unit advances. 
+ for n := uint64(20); n <= 29; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(13), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - b = NewBits(10) + b = NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(l, 9)) - assert.Equal(t, int64(0), b.lostCounter.Count()) - // 10 will set 0 index, 0 was already set, no lost packets - assert.True(t, b.Update(l, 10)) - assert.Equal(t, int64(0), b.lostCounter.Count()) - // 11 will set 1 index, 1 was missed, we should see 1 packet lost - assert.True(t, b.Update(l, 11)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - // Now let's fill in the window, should end up with 8 lost packets - assert.True(t, b.Update(l, 12)) - assert.True(t, b.Update(l, 13)) - assert.True(t, b.Update(l, 14)) + // Update(15) clears the warmup window (no lost), sets slot 15. assert.True(t, b.Update(l, 15)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + + // Update(16): slot 0 was already set (NewBits seeded it), and 16 is not + // strictly > length, so nothing is recorded as lost. assert.True(t, b.Update(l, 16)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + + // Update(17): we jumped straight from 0 to 15, so slot 1 was cleared + // (and never re-set). 17 > 16 is past warmup, so packet 1 is recorded lost. 
assert.True(t, b.Update(l, 17)) - assert.True(t, b.Update(l, 18)) - assert.True(t, b.Update(l, 19)) - assert.Equal(t, int64(8), b.lostCounter.Count()) + assert.Equal(t, int64(1), b.lostCounter.Count()) - // Jump ahead by a window size - assert.True(t, b.Update(l, 29)) - assert.Equal(t, int64(8), b.lostCounter.Count()) - // Now lets walk ahead normally through the window, the missed packets should fill in - assert.True(t, b.Update(l, 30)) - assert.True(t, b.Update(l, 31)) - assert.True(t, b.Update(l, 32)) - assert.True(t, b.Update(l, 33)) - assert.True(t, b.Update(l, 34)) - assert.True(t, b.Update(l, 35)) - assert.True(t, b.Update(l, 36)) - assert.True(t, b.Update(l, 37)) - assert.True(t, b.Update(l, 38)) - // 39 packets tracked, 22 seen, 17 lost - assert.Equal(t, int64(17), b.lostCounter.Count()) + // Fill in 18..30 in single steps. Each i evicts slot i%16. Slots 2..14 + // were all cleared during Update(15), and we never re-set any of them, + // so each i in 18..30 is a fresh lost packet — 13 more. + for n := uint64(18); n <= 30; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(14), b.lostCounter.Count()) - // Jump ahead by 2 windows, should have recording 1 full window missing - assert.True(t, b.Update(l, 58)) - assert.Equal(t, int64(27), b.lostCounter.Count()) - // Now lets walk ahead normally through the window, the missed packets should fill in from this window - assert.True(t, b.Update(l, 59)) - assert.True(t, b.Update(l, 60)) - assert.True(t, b.Update(l, 61)) - assert.True(t, b.Update(l, 62)) - assert.True(t, b.Update(l, 63)) - assert.True(t, b.Update(l, 64)) - assert.True(t, b.Update(l, 65)) - assert.True(t, b.Update(l, 66)) - assert.True(t, b.Update(l, 67)) - // 68 packets tracked, 32 seen, 36 missed - assert.Equal(t, int64(36), b.lostCounter.Count()) + // Jump ahead by exactly one window size. + assert.True(t, b.Update(l, 46)) + // end = min(46, 30+16) = 46, count = 16, all slots cleared. 
Before the + // jump every slot 0..15 had been set (Update(15), (16), (17), 18..30), + // so wasSet=16 and 46 == current+length means no past-window slack: + // lost contribution = 0. + assert.Equal(t, int64(14), b.lostCounter.Count()) + + // Walk 47..55. The Update(46) jump cleared every slot, so only slot 14 + // (for packet 46) is set when we start. Each subsequent unit step lands + // on a slot that was cleared and is past warmup, so it counts as lost. + // 9 more = 23. + for n := uint64(47); n <= 55; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(23), b.lostCounter.Count()) + + // Jump ahead by two windows: clears the window plus past-window loss. + assert.True(t, b.Update(l, 87)) + // current=55, length=16. end = min(87, 71) = 71. count=16, all slots + // cleared. Slots set before the clear are slots 14,15,0..7 (10 total). + // Lost from clear = 16 - 10 = 6. Past window: 87 - 55 - 16 = 16. +22. + assert.Equal(t, int64(45), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } func TestBitsLostCounterIssue1(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() + // Receive 4, backfill 1, then 9, 2, 3, 5, 6, 7 (skip 8), 10, 11, 14. + // Then jump to 25 — slot 25%16=9 is being evicted, but it had been set + // (we received packet 9), so no spurious lost increment. The original + // regression was about double-counting a missing packet when its slot + // got cleared on a jump. With the jump path now using clearRange's + // word-level wasSet count, the same semantics hold. 
assert.True(t, b.Update(l, 4)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 1)) @@ -244,7 +266,7 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 7)) assert.Equal(t, int64(0), b.lostCounter.Count()) - // assert.True(t, b.Update(l, 8)) + // Skip packet 8. assert.True(t, b.Update(l, 10)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 11)) @@ -252,9 +274,23 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.True(t, b.Update(l, 14)) assert.Equal(t, int64(0), b.lostCounter.Count()) - // Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter - assert.True(t, b.Update(l, 19)) + + // Jump to 25. With length=16, slot 25%16=9 corresponds to packet 9 + // (which we DID receive), so its bit is set and no lost++ from that + // eviction. The trace below shows the only loss is packet 8. + assert.True(t, b.Update(l, 25)) + // current was 14, i=25. end=min(25,30)=25. count=11. startPos=15. + // steady? current=14<16, so warmup branch: per-bit n=15..25, count those + // with !get(n) AND n>16. n=17..25 are >16. Among slots 17%16=1..25%16=9 + // did we set slots 1..9 (packets 1..9)? Yes for all but slot 8 (packet 8 + // was skipped). n=24 maps to slot 8 which is FALSE → lost++. All other + // n in 17..25 map to slots that are set. n=16 is not strictly > 16. So + // lost = 1. assert.Equal(t, int64(1), b.lostCounter.Count()) + + // Fill in 12, 13, 15, 16. Each is below current=25 (in-window). 16 must + // recheck slot 0 — it was set by NewBits and then cleared by the + // Update(25) jump, so 16 backfills cleanly. 
assert.True(t, b.Update(l, 12)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 13)) @@ -263,29 +299,140 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 16)) assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 17)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 18)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 20)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 21)) - // We missed packet 8 above + // We missed packet 8 above and that loss is still recorded once, never + // double-counted, never zeroed. assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } -func BenchmarkBits(b *testing.B) { - z := NewBits(10) - for n := 0; n < b.N; n++ { - for i := range z.bits { - z.bits[i] = true - } - for i := range z.bits { - z.bits[i] = false - } +// TestBitsWarmupOvershoot exercises the jump path's warmup arm with an +// overshoot past one full window. NewBits leaves current=0 with only slot 0 +// "set" by the marker. Jumping straight to length+k must (a) clear every +// slot the jump straddles, (b) count only past-window slack (not the +// in-window slots, which never had a "lost" tenant during warmup), and +// (c) leave the cursor at the new counter so subsequent unit advances +// count from steady state. The marker bit at slot 0 is irrelevant once +// current >= length. +func TestBitsWarmupOvershoot(t *testing.T) { + l := test.NewLogger() + b := NewBits(16) + b.lostCounter.Clear() + // Jump from current=0 to i=20 (length=16, overshoot=4). + // Warmup arm: counts slots in [1..16] where bit unset and n>length. + // Only n=16 was unset and >length: but slot 16%16=0 is the marker, + // so b.get(16) reads bits[0]=1 and skips. 
Result: 0 lost from the loop. + // Past-window: i - current - length = 20 - 0 - 16 = 4 lost. + assert.True(t, b.Update(l, 20)) + assert.Equal(t, int64(4), b.lostCounter.Count()) + assert.Equal(t, uint64(20), b.current) + + // Steady state now (current=20 >= length=16). Unit advance to 21 + // stomps slot 21%16=5, which was cleared by the jump and not reset, + // so this is +1 lost. + assert.True(t, b.Update(l, 21)) + assert.Equal(t, int64(5), b.lostCounter.Count()) +} + +// TestBitsCheckAcrossWarmupBoundary pins the underflow trick in Check's +// in-window clause. While in warmup, b.current-b.length underflows uint64 +// to a huge value so the first OR-clause is always false; the second +// clause (i < length && current < length) carries the in-window check. +// Once current >= length the regimes flip cleanly. +func TestBitsCheckAcrossWarmupBoundary(t *testing.T) { + l := test.NewLogger() + b := NewBits(16) + + // Warmup: current=0. Check(0) must read the marker (set) and return false. + assert.False(t, b.Check(l, 0), "marker slot should look already-received") + // Warmup: any 0 < i < length is in-window and unset → accepted. + for i := uint64(1); i < 16; i++ { + assert.True(t, b.Check(l, i), "warmup in-window i=%d should be accepted", i) + } + // Warmup: i >= length but > current is "next number" so accepted. + assert.True(t, b.Check(l, 16)) + assert.True(t, b.Check(l, 1_000_000)) + + // Cross into steady state. + assert.True(t, b.Update(l, 100)) + // Now current=100, length=16. In-window range is [85..100]. + // 84 is just outside: the underflow clause activates; 84 > 100-16=84 is false. + // And the warmup clause is false (current >= length). So out of window. + assert.False(t, b.Check(l, 84)) + // 85 sits at the boundary. 85 > 84 is true → in window, unset → accept. + assert.True(t, b.Check(l, 85)) + // 100 is current itself; not strictly greater, in-window, but already set. + assert.False(t, b.Check(l, 100)) + // Way out: clearly out of window. 
+ assert.False(t, b.Check(l, 50)) +} + +// TestBitsMarkerInvariant verifies the seeded bits[0]=1 marker behaves +// correctly across warmup and beyond. Update should never clear the marker +// during warmup (clearRange skips position 0 when startPos=1), and once +// current >= length the marker is no longer consulted by Check/Update on +// the live path — but it must still report counter 0 as a duplicate while +// we are in warmup. +func TestBitsMarkerInvariant(t *testing.T) { + l := test.NewLogger() + b := NewBits(8) + + // Counter 0 is the seeded marker; Check sees it as already received. + assert.False(t, b.Check(l, 0)) + // Update(0) at current=0 hits the duplicate branch. + b.dupeCounter.Clear() + assert.False(t, b.Update(l, 0)) + assert.Equal(t, int64(1), b.dupeCounter.Count()) + + // Walk forward through warmup; the marker must remain set. + for n := uint64(1); n <= 7; n++ { + assert.True(t, b.Update(l, n)) + } + // Position 0 (the marker) should still read as set because we never + // cleared it; Update(0) still looks like a duplicate. + assert.False(t, b.Check(l, 0)) + + // Cross into steady state with a unit advance to 8: pos=0, evicts the + // marker bit. The lost-counter guard (i > b.length) is false (8 == 8), + // so this advance does NOT charge a lost packet — exactly what the + // marker is there to prevent. + b.lostCounter.Clear() + assert.True(t, b.Update(l, 8)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + // The slot at pos 0 is now occupied by counter 8. + assert.False(t, b.Check(l, 8)) +} + +// BenchmarkBitsUpdateInOrder is the steady-state hot path: each call is +// i == current+1. +func BenchmarkBitsUpdateInOrder(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + z.Update(l, uint64(n)+1) + } +} + +// BenchmarkBitsUpdateReorder simulates light reorder within the window: +// every other packet arrives one slot behind its predecessor (forces the +// in-window backfill branch). 
+func BenchmarkBitsUpdateReorder(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + base := uint64(n) * 2 + z.Update(l, base+2) + z.Update(l, base+1) + } +} + +// BenchmarkBitsUpdateLargeJumps stresses the clearRange word-level path. +func BenchmarkBitsUpdateLargeJumps(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + z.Update(l, uint64(n+1)*1000) } } diff --git a/boring.go b/boring.go index 9cd9d37f..abe403fc 100644 --- a/boring.go +++ b/boring.go @@ -1,5 +1,4 @@ //go:build boringcrypto -// +build boringcrypto package nebula diff --git a/cert/ca_pool.go b/cert/ca_pool.go index 2bf480f2..792f8e66 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -1,11 +1,14 @@ package cert import ( + "bufio" + "bytes" + "encoding/pem" "errors" "fmt" + "io" "net/netip" "slices" - "strings" "time" ) @@ -29,22 +32,46 @@ func NewCAPool() *CAPool { // If the pool contains any expired certificates, an ErrExpired will be // returned along with the pool. The caller must handle any such errors. func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs)) +} + +// NewCAPoolFromPEMReader will create a new CA pool from the provided reader. +// The reader must contain a PEM-encoded set of nebula certificates. 
+func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) { pool := NewCAPool() - var err error + var expired bool - for { - caPEMs, err = pool.AddCAFromPEM(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil + + scanner := bufio.NewScanner(r) + scanner.Split(SplitPEM) + + for scanner.Scan() { + pemBytes := scanner.Bytes() + + block, rest := pem.Decode(pemBytes) + if len(bytes.TrimSpace(rest)) > 0 { + return nil, ErrInvalidPEMBlock } + if block == nil { + return nil, ErrInvalidPEMBlock + } + + c, err := unmarshalCertificateBlock(block) if err != nil { return nil, err } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break + + err = pool.AddCA(c) + if errors.Is(err, ErrExpired) { + expired = true + continue + } else if err != nil { + return nil, err } } + if err := scanner.Err(); err != nil { + return nil, ErrInvalidPEMBlock + } if expired { return pool, ErrExpired @@ -141,10 +168,23 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti return nil, err } + // Pre nebula v1.10.3 could generate signatures in either high or low s form and validation + // of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form + // but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we + // need to test both forms until such a time comes that we enforce low-s form on signature validation. 
+ fp2, err := CalculateAlternateFingerprint(c) + if err != nil { + return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err) + } + if fp2 != "" && ncp.IsBlocklisted(fp2) { + return nil, ErrBlockListed + } + cc := CachedCertificate{ Certificate: c, InvertedGroups: make(map[string]struct{}), Fingerprint: fp, + fingerprint2: fp2, signerFingerprint: signer.Fingerprint, } @@ -158,6 +198,11 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti // VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and // is a cheaper operation to perform as a result. func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { + // Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s + if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) { + return ErrBlockListed + } + _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint) return err } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index b0fdd5fb..ab173228 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -1,10 +1,14 @@ package cert import ( + "bytes" + "io" "net/netip" + "strings" "testing" "time" + "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -111,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe assert.Len(t, ppppp.CAs, 1) } +// oneByteReader wraps a reader to return at most 1 byte per Read call, +// exercising the streaming accumulation logic in NewCAPoolFromPEMReader. 
+type oneByteReader struct { + r io.Reader +} + +func (o *oneByteReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + return o.r.Read(p[:1]) +} + +func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) { + pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil)) + require.NoError(t, err) + assert.Empty(t, pool.CAs) + + pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n ")) + require.NoError(t, err) + assert.Empty(t, pool.CAs) +} + +func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) { + ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, pem2...) + pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)}) + require.NoError(t, err) + assert.Len(t, pool.CAs, 2) + + fp1, err := ca1.Fingerprint() + require.NoError(t, err) + fp2, err := ca2.Fingerprint() + require.NoError(t, err) + + assert.Contains(t, pool.CAs, fp1) + assert.Contains(t, pool.CAs, fp2) +} + +func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) { + _, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata")) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + +func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) { + _, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, []byte("some trailing garbage")...) 
+ _, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle)) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + func TestCertificateV1_Verify(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) @@ -170,6 +228,15 @@ func TestCertificateV1_VerifyP256(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") + // Create a copy of the cert and swap to the alternate form for the signature + nc := c.Copy() + b, err := p256.Swap(c.Signature()) + require.NoError(t, err) + require.NoError(t, nc.(*certificateV1).setSignature(b)) + + _, err = caPool.VerifyCertificate(time.Now(), nc) + require.EqualError(t, err, "certificate is in the block list") + caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) @@ -187,7 +254,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { require.NoError(t, err) caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) + b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) @@ -196,7 +263,17 @@ func TestCertificateV1_VerifyP256(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - _, err = caPool.VerifyCertificate(time.Now(), c) + cc, err := caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Reset the blocklist and block the alternate form fingerprint + caPool.ResetCertBlocklist() + caPool.BlocklistFingerprint(cc.fingerprint2) + err = caPool.VerifyCachedCertificate(time.Now(), cc) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } @@ -394,6 
+471,15 @@ func TestCertificateV2_VerifyP256(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") + // Create a copy of the cert and swap to the alternate form for the signature + nc := c.Copy() + b, err := p256.Swap(c.Signature()) + require.NoError(t, err) + require.NoError(t, nc.(*certificateV2).setSignature(b)) + + _, err = caPool.VerifyCertificate(time.Now(), nc) + require.EqualError(t, err, "certificate is in the block list") + caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) @@ -411,7 +497,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { require.NoError(t, err) caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) + b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) @@ -420,7 +506,17 @@ func TestCertificateV2_VerifyP256(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - _, err = caPool.VerifyCertificate(time.Now(), c) + cc, err := caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Reset the blocklist and block the alternate form fingerprint + caPool.ResetCertBlocklist() + caPool.BlocklistFingerprint(cc.fingerprint2) + err = caPool.VerifyCachedCertificate(time.Now(), cc) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } diff --git a/cert/cert.go b/cert/cert.go index 9d40e625..01d775e5 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -4,6 +4,8 @@ import ( "fmt" "net/netip" "time" + + "github.com/slackhq/nebula/cert/p256" ) type Version uint8 @@ -110,6 +112,9 @@ type CachedCertificate struct { InvertedGroups map[string]struct{} Fingerprint string signerFingerprint string + + // A place to store a 2nd fingerprint if the certificate could have one, 
such as with P256 + fingerprint2 string } func (cc *CachedCertificate) String() string { @@ -152,3 +157,31 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific return c, nil } + +// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates +// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step. +func CalculateAlternateFingerprint(c Certificate) (string, error) { + if c.Curve() != Curve_P256 { + return "", nil + } + + nc := c.Copy() + b, err := p256.Swap(nc.Signature()) + if err != nil { + return "", err + } + + switch v := nc.(type) { + case *certificateV1: + err = v.setSignature(b) + case *certificateV2: + err = v.setSignature(b) + default: + return "", ErrUnknownVersion + } + + if err != nil { + return "", err + } + return nc.Fingerprint() +} diff --git a/cert/p256/p256.go b/cert/p256/p256.go new file mode 100644 index 00000000..dc609a35 --- /dev/null +++ b/cert/p256/p256.go @@ -0,0 +1,127 @@ +package p256 + +import ( + "crypto/elliptic" + "errors" + "math/big" + + "filippo.io/bigmod" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1) +var nMod *bigmod.Modulus + +func init() { + n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes()) + if err != nil { + panic(err) + } + nMod = n +} + +func IsNormalized(sig []byte) (bool, error) { + r, s, err := parseSignature(sig) + if err != nil { + return false, err + } + return checkLowS(r, s), nil +} + +func checkLowS(_, s []byte) bool { + bigS := new(big.Int).SetBytes(s) + // Check if S <= (N/2), because we want to include the midpoint in the set of low-s + return bigS.Cmp(halfN) <= 0 +} + +func swap(r, s []byte) ([]byte, []byte, error) { + var err error + bigS, err := bigmod.NewNat().SetBytes(s, nMod) + if err != nil { + return nil, nil, err + } + sNormalized := nMod.Nat().Sub(bigS, nMod) + + 
result := sNormalized.Bytes(nMod) + for len(result) > 1 && result[0] == 0 { + result = result[1:] + } + + return r, result, nil +} + +func Normalize(sig []byte) ([]byte, error) { + r, s, err := parseSignature(sig) + if err != nil { + return nil, err + } + + if checkLowS(r, s) { + return sig, nil + } + + newR, newS, err := swap(r, s) + if err != nil { + return nil, err + } + + return encodeSignature(newR, newS) +} + +// Swap will change sig between its current form to the opposite high or low form. +func Swap(sig []byte) ([]byte, error) { + r, s, err := parseSignature(sig) + if err != nil { + return nil, err + } + + newR, newS, err := swap(r, s) + if err != nil { + return nil, err + } + + return encodeSignature(newR, newS) +} + +// parseSignature taken exactly from crypto/ecdsa/ecdsa.go +func parseSignature(sig []byte) (r, s []byte, err error) { + var inner cryptobyte.String + input := cryptobyte.String(sig) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(&r) || + !inner.ReadASN1Integer(&s) || + !inner.Empty() { + return nil, nil, errors.New("invalid ASN.1") + } + return r, s, nil +} + +func encodeSignature(r, s []byte) ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + addASN1IntBytes(b, r) + addASN1IntBytes(b, s) + }) + return b.Bytes() +} + +// addASN1IntBytes encodes in ASN.1 a positive integer represented as +// a big-endian byte slice with zero or more leading zeroes. 
+func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { + for len(bytes) > 0 && bytes[0] == 0 { + bytes = bytes[1:] + } + if len(bytes) == 0 { + b.SetError(errors.New("invalid integer")) + return + } + b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) { + if bytes[0]&0x80 != 0 { + c.AddUint8(0) + } + c.AddBytes(bytes) + }) +} diff --git a/cert/p256/p256_test.go b/cert/p256/p256_test.go new file mode 100644 index 00000000..486a7242 --- /dev/null +++ b/cert/p256/p256_test.go @@ -0,0 +1,28 @@ +package p256 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFlipping(t *testing.T) { + priv, err1 := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err1) + + out, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus")) + require.NoError(t, err) + + r, s, err := parseSignature(out) + require.NoError(t, err) + + r, s1, err := swap(r, s) + require.NoError(t, err) + r, s2, err := swap(r, s1) + require.NoError(t, err) + require.Equal(t, s, s2) + require.NotEqual(t, s, s1) +} diff --git a/cert/pem.go b/cert/pem.go index 8942c23a..84221b22 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -1,12 +1,66 @@ package cert import ( + "bytes" "encoding/pem" + "errors" "fmt" "golang.org/x/crypto/ed25519" ) +var ErrTruncatedPEMBlock = errors.New("truncated PEM block") + +// SplitPEM is a split function for bufio.Scanner that returns each PEM block. 
+func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Look for the start of a PEM block + start := bytes.Index(data, []byte("-----BEGIN ")) + if start == -1 { + if atEOF && len(bytes.TrimSpace(data)) > 0 { + // Non-whitespace content with no PEM block + return 0, nil, ErrTruncatedPEMBlock + } + if atEOF { + return len(data), nil, nil + } + // Request more data + return 0, nil, nil + } + + // Look for the end marker + endMarkerStart := bytes.Index(data[start:], []byte("-----END ")) + if endMarkerStart == -1 { + if atEOF { + // Incomplete PEM block at EOF + return 0, nil, ErrTruncatedPEMBlock + } + // Need more data to find the end + return 0, nil, nil + } + + // Find the actual end of the END line (after the newline) + endMarkerStart += start + endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n') + var end int + if endLineEnd == -1 { + if atEOF { + // END marker without newline at EOF - take it anyway + end = len(data) + } else { + // Need more data + return 0, nil, nil + } + } else { + end = endMarkerStart + endLineEnd + 1 + } + + // Extract the PEM block + pemBlock := data[start:end] + + // Return the valid PEM block + return end, pemBlock, nil +} + const ( //cert banners CertificateBanner = "NEBULA CERTIFICATE" CertificateV2Banner = "NEBULA CERTIFICATE V2" @@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } - var c Certificate - var err error - - switch p.Type { - // Implementations must validate the resulting certificate contains valid information - case CertificateBanner: - c, err = unmarshalCertificateV1(p.Bytes, nil) - case CertificateV2Banner: - c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) - default: - return nil, r, ErrInvalidPEMCertificateBanner - } - + c, err := unmarshalCertificateBlock(p) if err != nil { return nil, r, err } @@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { } 
+// unmarshalCertificateBlock decodes a single PEM block into a certificate. +// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise. +func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) { + switch block.Type { + // Implementations must validate the resulting certificate contains valid information + case CertificateBanner: + return unmarshalCertificateV1(block.Bytes, nil) + case CertificateV2Banner: + return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519) + default: + return nil, ErrInvalidPEMCertificateBanner + } +} + func marshalCertPublicKeyToPEM(c Certificate) []byte { if c.IsCA() { return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) diff --git a/cert/pem_test.go b/cert/pem_test.go index 310c57a3..ff623541 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -1,12 +1,88 @@ package cert import ( + "bufio" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func scanAll(t *testing.T, input string) ([]string, error) { + t.Helper() + scanner := bufio.NewScanner(strings.NewReader(input)) + scanner.Split(SplitPEM) + var blocks []string + for scanner.Scan() { + blocks = append(blocks, scanner.Text()) + } + return blocks, scanner.Err() +} + +func TestSplitPEM_Single(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_Multiple(t *testing.T) { + block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n" + block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, block1+block2) + require.NoError(t, err) + require.Len(t, blocks, 2) + require.Equal(t, block1, blocks[0]) + require.Equal(t, block2, blocks[1]) +} + +func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) { + input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END 
TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 2) +} + +func TestSplitPEM_Empty(t *testing.T) { + blocks, err := scanAll(t, "") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_WhitespaceOnly(t *testing.T) { + blocks, err := scanAll(t, " \n\t\n ") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_TrailingGarbage(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage" + blocks, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) + require.Len(t, blocks, 1) +} + +func TestSplitPEM_TruncatedBlock(t *testing.T) { + input := "-----BEGIN TEST-----\npartial data with no end" + _, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + +func TestSplitPEM_NoEndNewline(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_GarbageOnly(t *testing.T) { + _, err := scanAll(t, "this is not PEM data") + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + func TestUnmarshalCertificateFromPEM(t *testing.T) { goodCert := []byte(` # A good cert diff --git a/cert/sign.go b/cert/sign.go index 3eb08592..fbfffe4e 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -9,6 +9,8 @@ import ( "fmt" "net/netip" "time" + + "github.com/slackhq/nebula/cert/p256" ) // TBSCertificate represents a certificate intended to be signed. 
@@ -126,6 +128,13 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb return nil, err } + if curve == Curve_P256 { + sig, err = p256.Normalize(sig) + if err != nil { + return nil, err + } + } + err = c.setSignature(sig) if err != nil { return nil, err diff --git a/cert/sign_test.go b/cert/sign_test.go index e6f43cdf..bf4c9c0d 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -89,3 +90,48 @@ func TestCertificateV1_SignP256(t *testing.T) { require.NoError(t, err) assert.NotNil(t, uc) } + +func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: true, + Curve: Curve_P256, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + for i := 0; i < 1000; i++ { + if i&1 == 1 { + tbs.Version = Version1 + } else { + tbs.Version = Version2 + } + c, err := tbs.Sign(nil, Curve_P256, rawPriv) + require.NoError(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + normie, err := p256.IsNormalized(c.Signature()) + require.NoError(t, err) + assert.True(t, 
normie) + } +} diff --git a/cert_test/cert.go b/cert_test/cert.go index 75134316..c3759f12 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -163,3 +163,55 @@ func P256Keypair() ([]byte, []byte) { pubkey := privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } + +// DummyCert is a minimal cert.Certificate implementation for testing error paths. +type DummyCert struct { + Version_ cert.Version + Curve_ cert.Curve + Groups_ []string + IsCA_ bool + Issuer_ string + Name_ string + Networks_ []netip.Prefix + NotAfter_ time.Time + NotBefore_ time.Time + PublicKey_ []byte + Signature_ []byte + UnsafeNetworks_ []netip.Prefix +} + +func (d *DummyCert) Version() cert.Version { return d.Version_ } +func (d *DummyCert) Curve() cert.Curve { return d.Curve_ } +func (d *DummyCert) Groups() []string { return d.Groups_ } +func (d *DummyCert) IsCA() bool { return d.IsCA_ } +func (d *DummyCert) Issuer() string { return d.Issuer_ } +func (d *DummyCert) Name() string { return d.Name_ } +func (d *DummyCert) Networks() []netip.Prefix { return d.Networks_ } +func (d *DummyCert) NotAfter() time.Time { return d.NotAfter_ } +func (d *DummyCert) NotBefore() time.Time { return d.NotBefore_ } +func (d *DummyCert) PublicKey() []byte { return d.PublicKey_ } +func (d *DummyCert) Signature() []byte { return d.Signature_ } +func (d *DummyCert) UnsafeNetworks() []netip.Prefix { return d.UnsafeNetworks_ } +func (d *DummyCert) Fingerprint() (string, error) { return "", nil } +func (d *DummyCert) CheckSignature(key []byte) bool { return false } +func (d *DummyCert) MarshalForHandshakes() ([]byte, error) { return nil, nil } +func (d *DummyCert) MarshalPEM() ([]byte, error) { return nil, nil } +func (d *DummyCert) MarshalJSON() ([]byte, error) { return nil, nil } +func (d *DummyCert) Marshal() ([]byte, error) { return nil, nil } +func (d *DummyCert) String() string { return "dummy" } +func (d *DummyCert) Copy() cert.Certificate { return d } +func (d *DummyCert) VerifyPrivateKey(c 
cert.Curve, k []byte) error { return nil } +func (d *DummyCert) Expired(time.Time) bool { return false } +func (d *DummyCert) MarshalPublicKeyPEM() []byte { return nil } +func (d *DummyCert) PublicKeyPEM() []byte { return nil } + +// NewTestCAPool creates a CAPool from the given CA certificates, panicking on error. +func NewTestCAPool(cas ...cert.Certificate) *cert.CAPool { + pool := cert.NewCAPool() + for _, ca := range cas { + if err := pool.AddCA(ca); err != nil { + panic(err) + } + } + return pool +} diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index bea4d1d9..36258dd8 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "os" - "strings" "time" "github.com/slackhq/nebula/cert" @@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCACert, err := os.ReadFile(*vf.caPath) + caFile, err := os.Open(*vf.caPath) if err != nil { return fmt.Errorf("error while reading ca: %w", err) } + defer caFile.Close() - caPool := cert.NewCAPool() - for { - rawCACert, err = caPool.AddCAFromPEM(rawCACert) - if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %w", err) - } - - if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { - break - } + caPool, err := cert.NewCAPoolFromPEMReader(caFile) + if err != nil && !errors.Is(err, cert.ErrExpired) { + return fmt.Errorf("error while adding ca cert to pool: %w", err) } rawCert, err := os.ReadFile(*vf.certPath) diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f555e5f5..1aa5e8e6 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -64,7 +64,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) - require.EqualError(t, err, "error while adding ca cert to pool: input did not 
contain a valid PEM encoded block") + require.ErrorIs(t, err, cert.ErrInvalidPEMBlock) // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) diff --git a/cmd/nebula-service/logs_generic.go b/cmd/nebula-service/logs_generic.go index 3b7cdd1c..cc06b4c5 100644 --- a/cmd/nebula-service/logs_generic.go +++ b/cmd/nebula-service/logs_generic.go @@ -3,8 +3,15 @@ package main -import "github.com/sirupsen/logrus" +import ( + "log/slog" + "os" -func HookLogger(l *logrus.Logger) { - // Do nothing, let the logs flow to stdout/stderr + "github.com/slackhq/nebula/logging" +) + +// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows +// platforms have no special sink to integrate with. +func newPlatformLogger() *slog.Logger { + return logging.NewLogger(os.Stdout) } diff --git a/cmd/nebula-service/logs_windows.go b/cmd/nebula-service/logs_windows.go index af6480ef..ca0a55c5 100644 --- a/cmd/nebula-service/logs_windows.go +++ b/cmd/nebula-service/logs_windows.go @@ -1,54 +1,86 @@ package main import ( - "fmt" - "io/ioutil" - "os" + "context" + "log/slog" + "strings" + "sync" - "github.com/kardianos/service" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer -// logrus output will be discarded -func HookLogger(l *logrus.Logger) { - l.AddHook(newLogHook(logger)) - l.SetOutput(ioutil.Discard) +// newPlatformLogger returns a *slog.Logger that routes every log record +// through the Windows service logger so records end up in the Windows +// Event Log. 
All the heavy lifting (level management, format swap, +// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler; +// this file only contributes: +// +// - an io.Writer that forwards each formatted line to the service +// logger at the current record's Event Log severity, and +// - a thin severityTag that embeds *logging.Handler and overrides +// only Handle / WithAttrs / WithGroup, so Event Viewer's severity +// column and severity-based filters keep working the way they did +// before the slog migration. +// +// Format (text vs json) is carried by the embedded *logging.Handler, so +// logging.format: json in config still produces JSON lines in Event +// Viewer, same as the pre-slog logrus setup. +func newPlatformLogger() *slog.Logger { + w := &eventLogWriter{} + return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w}) } -type logHook struct { - sl service.Logger +// eventLogWriter forwards slog-formatted lines to the Windows service +// logger at the severity most recently stashed by severityTag.Handle. +// The mutex serializes the stash + inner.Handle + Write cycle per record +// across all concurrent goroutines; slog's builtin text/json handlers +// each hold their own mutex around Write, but that only protects the +// Write call itself, not our stash-then-handle sequence. 
+type eventLogWriter struct { + mu sync.Mutex + level slog.Level } -func newLogHook(sl service.Logger) *logHook { - return &logHook{sl: sl} -} - -func (h *logHook) Fire(entry *logrus.Entry) error { - line, err := entry.String() - if err != nil { - fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err) - return err - } - - switch entry.Level { - case logrus.PanicLevel: - return h.sl.Error(line) - case logrus.FatalLevel: - return h.sl.Error(line) - case logrus.ErrorLevel: - return h.sl.Error(line) - case logrus.WarnLevel: - return h.sl.Warning(line) - case logrus.InfoLevel: - return h.sl.Info(line) - case logrus.DebugLevel: - return h.sl.Info(line) +func (w *eventLogWriter) Write(p []byte) (int, error) { + line := strings.TrimRight(string(p), "\n") + switch { + case w.level >= slog.LevelError: + return len(p), logger.Error(line) + case w.level >= slog.LevelWarn: + return len(p), logger.Warning(line) default: - return nil + return len(p), logger.Info(line) } } -func (h *logHook) Levels() []logrus.Level { - return logrus.AllLevels +// severityTag embeds *logging.Handler to pick up everything it does for +// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat, +// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup +// so each record's slog.Level is stashed on the writer before formatting +// and so derived handlers stay wrapped as severityTag rather than +// downgrading to bare *logging.Handler. 
+type severityTag struct { + *logging.Handler + w *eventLogWriter +} + +func (s *severityTag) Handle(ctx context.Context, r slog.Record) error { + s.w.mu.Lock() + defer s.w.mu.Unlock() + s.w.level = r.Level + return s.Handler.Handle(ctx, r) +} + +func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return s + } + return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w} +} + +func (s *severityTag) WithGroup(name string) slog.Handler { + if name == "" { + return s + } + return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w} } diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 9a17b947..19fb3a9f 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -50,9 +50,14 @@ func main() { os.Exit(0) } + l := logging.NewLogger(os.Stdout) + if *serviceFlag != "" { - doService(configPath, configTest, Build, serviceFlag) - os.Exit(1) + if err := doService(configPath, configTest, Build, serviceFlag); err != nil { + l.Error("Service command failed", "error", err) + os.Exit(1) + } + return } if *configPath == "" { @@ -61,9 +66,6 @@ func main() { os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout - c := config.NewC(l) err := c.Load(*configPath) if err != nil { @@ -71,6 +73,16 @@ func main() { os.Exit(1) } + if err := logging.ApplyConfig(l, c); err != nil { + fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) @@ -78,8 
+90,20 @@ func main() { } if !*configTest { - ctrl.Start() - ctrl.ShutdownBlock() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.Error("Nebula stopped due to fatal error", "error", err) + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index a54fb0f3..6551ceb4 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -7,9 +7,9 @@ import ( "path/filepath" "github.com/kardianos/service" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" ) var logger service.Logger @@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error { // Start should not block. logger.Info("Nebula service starting.") - l := logrus.New() - HookLogger(l) + l := newPlatformLogger() c := config.NewC(l) err := c.Load(*p.configPath) @@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error { return fmt.Errorf("failed to load config: %s", err) } + if err := logging.ApplyConfig(l, c); err != nil { + return fmt.Errorf("failed to apply logging config: %s", err) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + p.control, err = nebula.Main(c, *p.configTest, Build, l, nil) if err != nil { return err @@ -57,11 +65,11 @@ func fileExists(filename string) bool { return true } -func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { +func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error { if *configPath == "" { ex, err := os.Executable() if err != nil { - panic(err) + return err } *configPath = filepath.Dir(ex) + "/config.yaml" if !fileExists(*configPath) { @@ -85,16 +93,16 
@@ func doService(configPath *string, configTest *bool, build string, serviceFlag * // Here are what the different loggers are doing: // - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr // - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log) - // - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use + // - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger s, err := service.New(prg, svcConfig) if err != nil { - log.Fatal(err) + return err } errs := make(chan error, 5) logger, err = s.Logger(errs) if err != nil { - log.Fatal(err) + return err } go func() { @@ -109,18 +117,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag * switch *serviceFlag { case "run": - err = s.Run() - if err != nil { + if err := s.Run(); err != nil { // Route any errors to the system logger logger.Error(err) } default: - err := service.Control(s, *serviceFlag) - if err != nil { + if err := service.Control(s, *serviceFlag); err != nil { log.Printf("Valid actions: %q\n", service.ControlAction) - log.Fatal(err) + return err } - return } + return nil } diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index ffdc15bf..d7f0de93 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -55,8 +55,7 @@ func main() { os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout + l := logging.NewLogger(os.Stdout) c := config.NewC(l) err := c.Load(*configPath) @@ -65,6 +64,16 @@ func main() { os.Exit(1) } + if err := logging.ApplyConfig(l, c); err != nil { + 
fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) @@ -72,9 +81,21 @@ func main() { } if !*configTest { - ctrl.Start() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() notifyReady(l) - ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.Error("Nebula stopped due to fatal error", "error", err) + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula/notify_linux.go b/cmd/nebula/notify_linux.go index 8c3dca55..965986a9 100644 --- a/cmd/nebula/notify_linux.go +++ b/cmd/nebula/notify_linux.go @@ -1,11 +1,10 @@ package main import ( + "log/slog" "net" "os" "time" - - "github.com/sirupsen/logrus" ) // SdNotifyReady tells systemd the service is ready and dependent services can now be started @@ -13,30 +12,30 @@ import ( // https://www.freedesktop.org/software/systemd/man/systemd.service.html const SdNotifyReady = "READY=1" -func notifyReady(l *logrus.Logger) { +func notifyReady(l *slog.Logger) { sockName := os.Getenv("NOTIFY_SOCKET") if sockName == "" { - l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") + l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal") return } conn, err := net.DialTimeout("unixgram", sockName, time.Second) if err != nil { - l.WithError(err).Error("failed to connect to systemd notification socket") + l.Error("failed to connect to systemd notification socket", "error", err) return } defer conn.Close() err = conn.SetWriteDeadline(time.Now().Add(time.Second)) if err != nil { - l.WithError(err).Error("failed to set the write deadline for the systemd 
notification socket") + l.Error("failed to set the write deadline for the systemd notification socket", "error", err) return } if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { - l.WithError(err).Error("failed to signal the systemd notification socket") + l.Error("failed to signal the systemd notification socket", "error", err) return } - l.Debugln("notified systemd the service is ready") + l.Debug("notified systemd the service is ready") } diff --git a/cmd/nebula/notify_notlinux.go b/cmd/nebula/notify_notlinux.go index e7758e09..48cfe949 100644 --- a/cmd/nebula/notify_notlinux.go +++ b/cmd/nebula/notify_notlinux.go @@ -3,8 +3,8 @@ package main -import "github.com/sirupsen/logrus" +import "log/slog" -func notifyReady(_ *logrus.Logger) { +func notifyReady(_ *slog.Logger) { // No init service to notify } diff --git a/config/config.go b/config/config.go index 0d1be128..5bf994a1 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "math" "os" "os/signal" @@ -16,7 +17,6 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "go.yaml.in/yaml/v3" ) @@ -26,11 +26,11 @@ type C struct { Settings map[string]any oldSettings map[string]any callbacks []func(*C) - l *logrus.Logger + l *slog.Logger reloadLock sync.Mutex } -func NewC(l *logrus.Logger) *C { +func NewC(l *slog.Logger) *C { return &C{ Settings: make(map[string]any), l: l, @@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool { newVals, err := yaml.Marshal(nv) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") + c.l.Error("Error while marshaling new config", + "config_path", k, + "error", err, + ) } oldVals, err := yaml.Marshal(ov) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") + c.l.Error("Error while marshaling old config", + "config_path", k, + "error", err, + ) } return string(newVals) != string(oldVals) @@ 
-154,7 +160,10 @@ func (c *C) ReloadConfig() { err := c.Load(c.path) if err != nil { - c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") + c.l.Error("Error occurred while reloading config", + "config_path", c.path, + "error", err, + ) return } diff --git a/connection_manager.go b/connection_manager.go index 4c2f26ef..e7fc04cd 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -5,13 +5,13 @@ import ( "context" "encoding/binary" "fmt" + "log/slog" "net/netip" "sync" "sync/atomic" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -47,10 +47,10 @@ type connectionManager struct { metricsTxPunchy metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { +func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ hostMap: hm, l: l, @@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.getInactivityTimeout() cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) if !initial { - cm.l.WithField("oldDuration", old). - WithField("newDuration", cm.getInactivityTimeout()). - Info("Inactivity timeout has changed") + cm.l.Info("Inactivity timeout has changed", + "oldDuration", old, + "newDuration", cm.getInactivityTimeout(), + ) } } @@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.dropInactive.Load() cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) if !initial { - cm.l.WithField("oldBool", old). - WithField("newBool", cm.dropInactive.Load()). 
- Info("Drop inactive setting has changed") + cm.l.Info("Drop inactive setting has changed", + "oldBool", old, + "newBool", cm.dropInactive.Load(), + ) } } } @@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo var err error index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { - cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") + cm.l.Error("failed to migrate relay to new hostinfo", "error", err) continue } switch r.Type { @@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo msg, err := req.Marshal() if err != nil { - cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") + cm.l.Error("failed to marshal Control message to migrate relay", "error", err) } else { cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) - cm.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromAddr, - "relayTo": req.RelayToAddr, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddrs": newhostinfo.vpnAddrs}). 
- Info("send CreateRelayRequest") + cm.l.Info("send CreateRelayRequest", + "relayFrom", req.RelayFromAddr, + "relayTo", req.RelayToAddr, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddrs", newhostinfo.vpnAddrs, + ) } } } @@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim hostinfo := cm.hostMap.Indexes[localIndex] if hostinfo == nil { - cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") + cm.l.Debug("Not found in hostmap", "localIndex", localIndex) return doNothing, nil, nil } @@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim // A hostinfo is determined alive if there is incoming traffic if inTraffic { decision := doNothing - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "alive", "method": "passive"}, + ) } hostinfo.pendingDeletion.Store(false) @@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if hostinfo.pendingDeletion.Load() { // We have already sent a test packet and nothing was returned, this hostinfo is dead - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "dead", "method": "active"}). - Info("Tunnel status") + hostinfo.logger(cm.l).Info("Tunnel status", + "tunnelCheck", m{"state": "dead", "method": "active"}, + ) return deleteTunnel, hostinfo, nil } @@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim inactiveFor, isInactive := cm.isInactive(hostinfo, now) if isInactive { // Tunnel is inactive, tear it down - hostinfo.logger(cm.l). - WithField("inactiveDuration", inactiveFor). - WithField("primary", mainHostInfo). 
- Info("Dropping tunnel due to inactivity") + hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity", + "inactiveDuration", inactiveFor, + "primary", mainHostInfo, + ) return closeTunnel, hostinfo, primary } @@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim cm.sendPunch(hostinfo) } - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "testing", "method": "active"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "testing", "method": "active"}, + ) } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues decision = sendTestPacket } else { - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l).Debugf("Hostinfo sadness") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Hostinfo sadness") } } @@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI return false //cert is still valid! yay! } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed // Block listed certificates should always be disconnected - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). - Info("Remote certificate is blocked, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else if cm.intf.disconnectInvalid.Load() { - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). 
- Info("Remote certificate is no longer valid, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else { //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open @@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { curCrtVersion := curCrt.Version() myCrt := cs.getCertificate(curCrtVersion) if myCrt == nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("reason", "local certificate removed"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "reason", "local certificate removed", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } @@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { // if our certificate version is less than theirs, and we have a matching version available, rehandshake? if cs.getCertificate(peerCrt.Certificate.Version()) != nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("peerVersion", peerCrt.Certificate.Version()). - WithField("reason", "local certificate version lower than peer, attempting to correct"). 
- Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "peerVersion", peerCrt.Certificate.Version(), + "reason", "local certificate version lower than peer, attempting to correct", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { hh.initiatingVersionOverride = peerCrt.Certificate.Version() }) @@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { } } if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "local certificate is not current"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "local certificate is not current", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } if curCrtVersion < cs.initiatingVersion { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "current cert version < pki.initiatingVersion"). 
- Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "current cert version < pki.initiatingVersion", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return diff --git a/connection_manager_test.go b/connection_manager_test.go index 647dd72b..7dc08a45 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/overlaytest" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -46,13 +46,13 @@ func Test_NewConnectionManagerTest(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -63,9 +63,9 @@ func Test_NewConnectionManagerTest(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -79,7 +79,6 @@ func Test_NewConnectionManagerTest(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -129,13 +128,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: 
cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -146,9 +145,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -162,7 +161,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -214,13 +212,13 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -231,12 +229,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce @@ -248,7 +246,6 @@ func 
Test_NewConnectionManager_DisconnectInactive(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -339,15 +336,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - privateKey: []byte{}, - v1Cert: &dummyCert{}, - v1HandshakeBytes: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1Credential: nil, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -360,9 +357,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.disconnectInvalid.Store(true) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc @@ -371,7 +368,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ConnectionState: &ConnectionState{ myCert: &dummyCert{}, peerCert: cachedPeerCert, - H: &noise.HandshakeState{}, }, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) diff --git a/connection_state.go b/connection_state.go index db885d42..47e23b5a 100644 --- a/connection_state.go +++ b/connection_state.go @@ -1,16 +1,12 @@ package nebula import ( - "crypto/rand" "encoding/json" - "fmt" "sync" "sync/atomic" - "github.com/flynn/noise" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/handshake" ) const ReplayWindow = 1024 @@ -18,7 +14,6 @@ const ReplayWindow = 1024 type 
ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState - H *noise.HandshakeState myCert cert.Certificate peerCert *cert.CachedCertificate initiator bool @@ -27,55 +22,24 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { - var dhFunc noise.DHFunc - switch crt.Curve() { - case cert.Curve_CURVE25519: - dhFunc = noise.DH25519 - case cert.Curve_P256: - if cs.pkcs11Backed { - dhFunc = noiseutil.DHP256PKCS11 - } else { - dhFunc = noiseutil.DHP256 - } - default: - return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) - } - - var ncs noise.CipherSuite - if cs.cipher == "chachapoly" { - ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) - } else { - ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) - } - - static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} - hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: ncs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - //NOTE: These should come from CertState (pki.go) when we finally implement it - PresharedKey: []byte{}, - PresharedKeyPlacement: 0, - }) - if err != nil { - return nil, fmt.Errorf("NewConnectionState: %s", err) - } - - // The queue and ready params prevent a counter race that would happen when - // sending stored packets and simultaneously accepting new traffic. +// newConnectionStateFromResult builds a fully-populated ConnectionState from a +// completed handshake.Result. It seeds messageCounter and the replay window so +// that the post-handshake message indices already used on the wire don't count +// as missed traffic in the data plane. 
+func newConnectionStateFromResult(r *handshake.Result) *ConnectionState { ci := &ConnectionState{ - H: hs, - initiator: initiator, + myCert: r.MyCert, + initiator: r.Initiator, + peerCert: r.RemoteCert, + eKey: NewNebulaCipherState(r.EKey), + dKey: NewNebulaCipherState(r.DKey), window: NewBits(ReplayWindow), - myCert: crt, } - // always start the counter from 2, as packet 1 and packet 2 are handshake packets. - ci.messageCounter.Add(2) - - return ci, nil + ci.messageCounter.Add(r.MessageIndex) + for i := uint64(1); i <= r.MessageIndex; i++ { + ci.window.Update(nil, i) + } + return ci } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { diff --git a/connection_state_test.go b/connection_state_test.go new file mode 100644 index 00000000..dea60d39 --- /dev/null +++ b/connection_state_test.go @@ -0,0 +1,114 @@ +package nebula + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// runTestHandshake runs a complete IX handshake between two freshly-built +// peers and returns the initiator and responder Results. Used to produce +// real cipher states for tests that need to exercise post-handshake glue. 
+func runTestHandshake(t *testing.T) (initR, respR *handshake.Result) { + t.Helper() + + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + makeCreds := func(name string, networks []netip.Prefix) handshake.GetCredentialFunc { + c, _, rawKey, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil, + ) + priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawKey) + require.NoError(t, err) + hsBytes, err := c.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + cred := handshake.NewCredential(c, hsBytes, priv, ncs) + return func(v cert.Version) *handshake.Credential { + if v == cert.Version2 { + return cred + } + return nil + } + } + + verifier := func(c cert.Certificate) (*cert.CachedCertificate, error) { + return caPool.VerifyCertificate(time.Now(), c) + } + + initCreds := makeCreds("initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCreds := makeCreds("responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM, err := handshake.NewMachine( + cert.Version2, initCreds, verifier, + func() (uint32, error) { return 1000, nil }, + true, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + + respM, err := handshake.NewMachine( + cert.Version2, respCreds, verifier, + func() (uint32, error) { return 2000, nil }, + false, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp, respR, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, respR) + + _, initR, err = initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initR) + + return initR, respR +} + +func TestNewConnectionStateFromResult(t *testing.T) { + initR, respR := runTestHandshake(t) + + 
t.Run("initiator", func(t *testing.T) { + ci := newConnectionStateFromResult(initR) + assert.True(t, ci.initiator) + assert.Equal(t, initR.MyCert, ci.myCert) + assert.Equal(t, initR.RemoteCert, ci.peerCert) + assert.NotNil(t, ci.eKey) + assert.NotNil(t, ci.dKey) + + // IX has 2 handshake messages; the next data-plane send is counter=3. + assert.Equal(t, uint64(2), ci.messageCounter.Load(), + "messageCounter must equal Result.MessageIndex so the next send is N+1") + + // Both handshake counters must be marked seen so they don't appear lost. + // Check returns false if an index has already been recorded. + assert.False(t, ci.window.Check(nil, 1), "counter 1 must already be seen") + assert.False(t, ci.window.Check(nil, 2), "counter 2 must already be seen") + // Counter 3 is the next data-plane message and must NOT be pre-marked. + assert.True(t, ci.window.Check(nil, 3), "counter 3 must not be pre-seeded") + }) + + t.Run("responder", func(t *testing.T) { + ci := newConnectionStateFromResult(respR) + assert.False(t, ci.initiator) + assert.Equal(t, respR.MyCert, ci.myCert) + assert.Equal(t, respR.RemoteCert, ci.peerCert) + assert.NotNil(t, ci.eKey) + assert.NotNil(t, ci.dKey) + assert.Equal(t, uint64(2), ci.messageCounter.Load()) + }) +} diff --git a/control.go b/control.go index f8567b50..ef58988b 100644 --- a/control.go +++ b/control.go @@ -2,17 +2,33 @@ package nebula import ( "context" + "errors" + "log/slog" "net/netip" "os" "os/signal" + "sync" "syscall" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" ) +type RunState int + +const ( + StateUnknown RunState = iota + StateReady + StateStarted + StateStopping + StateStopped +) + +var ErrAlreadyStarted = errors.New("nebula is already started") +var ErrAlreadyStopped = errors.New("nebula cannot be restarted") +var ErrUnknownState = errors.New("nebula state is invalid") + // Every interaction here needs to take extra care to copy 
memory and not return or use arguments "as is" when touching // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc @@ -26,8 +42,11 @@ type controlHostLister interface { } type Control struct { + stateLock sync.Mutex + state RunState + f *Interface - l *logrus.Logger + l *slog.Logger ctx context.Context cancel context.CancelFunc sshStart func() @@ -49,10 +68,31 @@ type ControlHostInfo struct { CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } -// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() -func (c *Control) Start() { +// Start actually runs nebula, this is a nonblocking call. +// The returned function blocks until nebula has fully stopped and returns the +// first fatal reader error (if any). A nil error means nebula shut down +// gracefully; a non-nil error means a reader hit an unexpected failure that +// triggered the shutdown. +func (c *Control) Start() (func() error, error) { + c.stateLock.Lock() + defer c.stateLock.Unlock() + switch c.state { + case StateReady: + //yay! + case StateStopped, StateStopping: + return nil, ErrAlreadyStopped + case StateStarted: + return nil, ErrAlreadyStarted + default: + return nil, ErrUnknownState + } + // Activate the interface - c.f.activate() + err := c.f.activate() + if err != nil { + c.state = StateStopped + return nil, err + } // Call all the delayed funcs that waited patiently for the interface to be created. if c.sshStart != nil { @@ -71,25 +111,51 @@ func (c *Control) Start() { c.lighthouseStart() } + c.f.triggerShutdown = c.Stop + // Start reading packets. 
- c.f.run() + out, err := c.f.run() + if err != nil { + c.state = StateStopped + return nil, err + } + c.state = StateStarted + return out, nil +} + +func (c *Control) State() RunState { + c.stateLock.Lock() + defer c.stateLock.Unlock() + return c.state } func (c *Control) Context() context.Context { return c.ctx } -// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete +// Stop is a non-blocking call that signals nebula to close all tunnels and shut down func (c *Control) Stop() { + c.stateLock.Lock() + if c.state != StateStarted { + c.stateLock.Unlock() + // We are stopping or stopped already + return + } + + c.state = StateStopping + c.stateLock.Unlock() + // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. c.cancel() c.CloseAllTunnels(false) if err := c.f.Close(); err != nil { - c.l.WithError(err).Error("Close interface failed") + c.l.Error("Close interface failed", "error", err) } - c.l.Info("Goodbye") + c.stateLock.Lock() + c.state = StateStopped + c.stateLock.Unlock() } // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled @@ -100,7 +166,7 @@ func (c *Control) ShutdownBlock() { rawSig := <-sigChan sig := rawSig.String() - c.l.WithField("signal", sig).Info("Caught signal, shutting down") + c.l.Info("Caught signal, shutting down", "signal", sig) c.Stop() } @@ -237,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). 
- Debug("Sending close tunnel message") + c.l.Debug("Sending close tunnel message", + "vpnAddrs", h.vpnAddrs, + "udpAddr", h.remote, + ) closed++ } diff --git a/control_test.go b/control_test.go index e8a5d312..5e381c46 100644 --- a/control_test.go +++ b/control_test.go @@ -6,7 +6,6 @@ import ( "reflect" "testing" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -79,10 +78,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, &Interface{}) c := Control{ + state: StateReady, f: &Interface{ hostMap: hm, }, - l: logrus.New(), + l: test.NewLogger(), } thi := c.GetHostInfoByVpnAddr(vpnIp, false) diff --git a/control_tester.go b/control_tester.go index 7403a745..728ac649 100644 --- a/control_tester.go +++ b/control_tester.go @@ -1,13 +1,10 @@ //go:build e2e_testing -// +build e2e_testing package nebula import ( "net/netip" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" @@ -23,7 +20,9 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message panic(err) } pipeTo.InjectUDPPacket(p) - if h.Type == msgType && h.Subtype == subType { + match := h.Type == msgType && h.Subtype == subType + p.Release() + if match { return } } @@ -39,7 +38,9 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, panic(err) } pipeTo.InjectUDPPacket(p) - if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType { + match := h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType + p.Release() + if match { return } } @@ -91,65 +92,15 @@ func (c *Control) GetTunTxChan() <-chan []byte { return c.f.inside.(*overlay.TestTun).TxPackets } -// InjectUDPPacket will inject a packet into the udp side of nebula +// InjectUDPPacket injects a packet into the udp side. 
We copy internally so the caller keeps ownership of p. +// The copy comes from the freelist so steady-state alloc is zero. func (c *Control) InjectUDPPacket(p *udp.Packet) { - c.f.outside.(*udp.TesterConn).Send(p) + c.f.outside.(*udp.TesterConn).Send(p.Copy()) } -// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { - serialize := make([]gopacket.SerializableLayer, 0) - var netLayer gopacket.NetworkLayer - if toAddr.Is6() { - if !fromAddr.Is6() { - panic("Cant send ipv6 to ipv4") - } - ip := &layers.IPv6{ - Version: 6, - NextHeader: layers.IPProtocolUDP, - SrcIP: fromAddr.Unmap().AsSlice(), - DstIP: toAddr.Unmap().AsSlice(), - } - serialize = append(serialize, ip) - netLayer = ip - } else { - if !fromAddr.Is4() { - panic("Cant send ipv4 to ipv6") - } - - ip := &layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: fromAddr.Unmap().AsSlice(), - DstIP: toAddr.Unmap().AsSlice(), - } - serialize = append(serialize, ip) - netLayer = ip - } - - udp := layers.UDP{ - SrcPort: layers.UDPPort(fromPort), - DstPort: layers.UDPPort(toPort), - } - err := udp.SetNetworkLayerForChecksum(netLayer) - if err != nil { - panic(err) - } - - buffer := gopacket.NewSerializeBuffer() - opt := gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } - - serialize = append(serialize, &udp, gopacket.Payload(data)) - err = gopacket.SerializeLayers(buffer, opt, serialize...) - if err != nil { - panic(err) - } - - c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) +// InjectTunPacket pushes an IP packet onto the tun interface. 
+func (c *Control) InjectTunPacket(packet []byte) { + c.f.inside.(*overlay.TestTun).Send(packet) } func (c *Control) GetVpnAddrs() []netip.Addr { diff --git a/dist/wireshark/nebula.lua b/dist/wireshark/nebula.lua index ddc808f9..d17dc7a0 100644 --- a/dist/wireshark/nebula.lua +++ b/dist/wireshark/nebula.lua @@ -84,30 +84,24 @@ end function nebula.prefs_changed() if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then - -- Nothing changed, bail return end - -- Remove our old dissector + -- Remove all existing registrations DissectorTable.get("udp.port"):remove_all(nebula) - if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then - default_settings.all_port = nebula.prefs.all_ports - + if nebula.prefs.all_ports then + -- Register on every port for hole punch capture for i=0, 65535 do DissectorTable.get("udp.port"):add(i, nebula) end - - -- no need to establish again on specific ports - return + else + -- Register on the configured port only + DissectorTable.get("udp.port"):add(nebula.prefs.port, nebula) end - - if default_settings.all_ports ~= nebula.prefs.all_ports then - -- Add our new port dissector - default_settings.port = nebula.prefs.port - DissectorTable.get("udp.port"):add(default_settings.port, nebula) - end + default_settings.all_ports = nebula.prefs.all_ports + default_settings.port = nebula.prefs.port end DissectorTable.get("udp.port"):add(default_settings.port, nebula) diff --git a/dns_server.go b/dns_server.go index 73576546..ff1369ab 100644 --- a/dns_server.go +++ b/dns_server.go @@ -1,63 +1,249 @@ package nebula import ( + "context" "fmt" + "log/slog" "net" "net/netip" "strconv" "strings" "sync" + "sync/atomic" "github.com/gaissmai/bart" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) -// This whole thing should be rewritten to use context - -var dnsR *dnsRecords -var dnsServer *dns.Server -var dnsAddr string - -type 
dnsRecords struct { +type dnsServer struct { sync.RWMutex - l *logrus.Logger + l *slog.Logger + ctx context.Context dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr hostMap *HostMap myVpnAddrsTable *bart.Lite + + mux *dns.ServeMux + + // enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`. + // Start, Add, and reload consult it so callers don't need to know the + // gating rules. When it toggles off via reload, accumulated records are + // cleared so a later re-enable starts with a fresh map populated from + // new handshakes. + enabled atomic.Bool + + serverMu sync.Mutex + server *dns.Server + // started is closed once `server` has finished binding (or after + // ListenAndServe returns on a bind failure). Stop waits on it before + // calling Shutdown to avoid the miekg/dns "server not started" race + // where a Shutdown that arrives before bind completes is silently + // ignored, leaving the listener running forever. + started chan struct{} + addr string } -func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { - return &dnsRecords{ +// newDnsServerFromConfig builds a dnsServer, applies the initial config, and +// registers a reload callback. The reload callback is registered before the +// initial config is applied, so a SIGHUP can later enable, fix, or disable +// DNS even if the initial application failed. +// +// The dnsServer internally gates on `lighthouse.serve_dns && +// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally, +// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel +// watcher that tears the listener down on nebula shutdown. The returned +// pointer is always non-nil, even on error. 
+func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { + ds := &dnsServer{ l: l, + ctx: ctx, dnsMap4: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr), hostMap: hostMap, myVpnAddrsTable: cs.myVpnAddrsTable, } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", ds.handleDnsRequest) + + c.RegisterReloadCallback(func(c *config.C) { + if err := ds.reload(c, false); err != nil { + ds.l.Error("Failed to reload DNS responder from config", "error", err) + } + }) + + if err := ds.reload(c, true); err != nil { + return ds, err + } + return ds, nil } -func (d *dnsRecords) Query(q uint16, data string) netip.Addr { +// reload applies the latest config and reconciles the running state with it: +// - enabled toggled on -> spawn a runner +// - enabled toggled off -> stop the runner +// - listen address changed (while running) -> restart on the new address +// - everything else -> no-op +// +// On the initial call it only records configuration; Control.Start is what +// launches the first runner via dnsStart. +func (d *dnsServer) reload(c *config.C, initial bool) error { + wantsDns := c.GetBool("lighthouse.serve_dns", false) + amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) + enabled := wantsDns && amLighthouse + newAddr := getDnsServerAddr(c) + + d.serverMu.Lock() + running := d.server + runningStarted := d.started + sameAddr := d.addr == newAddr + d.addr = newAddr + d.enabled.Store(enabled) + d.serverMu.Unlock() + + if initial { + if wantsDns && !amLighthouse { + d.l.Warn("DNS server refusing to run because this host is not a lighthouse.") + } + return nil + } + + if !enabled { + if running != nil { + d.Stop() + } + // Drop any records that accumulated while enabled; a later re-enable + // will repopulate from fresh handshakes. + d.clearRecords() + return nil + } + + if running == nil { + // Was disabled (or never started); bring it up now. 
+ go d.Start() + return nil + } + + if sameAddr { + return nil + } + + d.shutdownServer(running, runningStarted, "reload") + // Old Start goroutine has now exited; bring up a fresh listener on the + // new address. + go d.Start() + return nil +} + +// shutdownServer waits for the server to finish binding (so Shutdown actually +// stops it rather than no-oping) and then shuts it down. +func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) { + if srv == nil { + return + } + if started != nil { + <-started + } + if err := srv.Shutdown(); err != nil { + d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err) + } +} + +// Start binds and serves the DNS responder. Blocks until Stop is called or +// the listener errors. Safe to call when DNS is disabled (returns +// immediately). This is what Control.dnsStart points at. +// +// Must be invoked after the tun device is active so that lighthouse.dns.host +// may bind to a nebula IP. +func (d *dnsServer) Start() { + if !d.enabled.Load() { + return + } + + started := make(chan struct{}) + d.serverMu.Lock() + if d.ctx.Err() != nil { + d.serverMu.Unlock() + return + } + addr := d.addr + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: d.mux, + NotifyStartedFunc: func() { close(started) }, + } + d.server = server + d.started = started + d.serverMu.Unlock() + + // Per-invocation ctx watcher. Exits when Start does, so we don't leak a + // watcher per reload-driven restart. + done := make(chan struct{}) + go func() { + select { + case <-d.ctx.Done(): + d.shutdownServer(server, started, "shutdown") + case <-done: + } + }() + + d.l.Info("Starting DNS responder", "dnsListener", addr) + err := server.ListenAndServe() + close(done) + + // If the listener never bound (bind error) NotifyStartedFunc never fires, + // so close started here to release any Stop caller waiting on it. 
+ select { + case <-started: + default: + close(started) + } + + if err != nil { + d.l.Warn("Failed to run the DNS responder", "error", err) + } +} + +// Stop shuts down the active server, if any. Idempotent. +func (d *dnsServer) Stop() { + d.serverMu.Lock() + srv := d.server + started := d.started + d.server = nil + d.started = nil + d.serverMu.Unlock() + d.shutdownServer(srv, started, "stop") +} + +// Query returns the address for the given name and query type. The second +// return value reports whether the name is known at all (in either A or AAAA), +// which lets callers distinguish NODATA from NXDOMAIN. +func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) { data = strings.ToLower(data) d.RLock() defer d.RUnlock() + addr4, haveV4 := d.dnsMap4[data] + addr6, haveV6 := d.dnsMap6[data] + nameExists := haveV4 || haveV6 switch q { case dns.TypeA: - if r, ok := d.dnsMap4[data]; ok { - return r + if haveV4 { + return addr4, nameExists } case dns.TypeAAAA: - if r, ok := d.dnsMap6[data]; ok { - return r + if haveV6 { + return addr6, nameExists } } - return netip.Addr{} + return netip.Addr{}, nameExists } -func (d *dnsRecords) QueryCert(data string) string { +func (d *dnsServer) QueryCert(data string) string { + if len(data) < 2 { + return "" + } ip, err := netip.ParseAddr(data[:len(data)-1]) if err != nil { return "" @@ -80,8 +266,19 @@ func (d *dnsRecords) QueryCert(data string) string { return string(b) } +// clearRecords drops all DNS records. 
+func (d *dnsServer) clearRecords() { + d.Lock() + defer d.Unlock() + clear(d.dnsMap4) + clear(d.dnsMap6) +} + // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` -func (d *dnsRecords) Add(host string, addresses []netip.Addr) { +func (d *dnsServer) Add(host string, addresses []netip.Addr) { + if !d.enabled.Load() { + return + } host = strings.ToLower(host) d.Lock() defer d.Unlock() @@ -101,7 +298,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) { } } -func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { +func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { a, _, _ := net.SplitHostPort(addr) b, err := netip.ParseAddr(a) if err != nil { @@ -116,13 +313,24 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { return d.myVpnAddrsTable.Contains(b) } -func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { + debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug) + // Per RFC 2308 §2.2, a name that exists but has no record of the requested + // type must be answered with NOERROR and an empty answer section (NODATA), + // not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not + // exist at all. 
+ anyNameExists := false for _, q := range m.Question { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: qType := dns.TypeToString[q.Qtype] - d.l.Debugf("Query for %s %s", qType, q.Name) - ip := d.Query(q.Qtype, q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", qType, "name", q.Name) + } + ip, nameExists := d.Query(q.Qtype, q.Name) + if nameExists { + anyNameExists = true + } if ip.IsValid() { rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { @@ -134,7 +342,9 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - d.l.Debugf("Query for TXT %s", q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", "TXT", "name", q.Name) + } ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) @@ -145,12 +355,12 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { } } - if len(m.Answer) == 0 { + if len(m.Answer) == 0 && !anyNameExists { m.Rcode = dns.RcodeNameError } } -func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false @@ -163,21 +373,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(l, cs, hostMap) - - // attach request handler func - dns.HandleFunc(".", dnsR.handleDnsRequest) - - c.RegisterReloadCallback(func(c *config.C) { - reloadDns(l, c) - }) - - return func() { - startDns(l, c) - } -} - func getDnsServerAddr(c *config.C) string { dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve. 
@@ -186,25 +381,3 @@ func getDnsServerAddr(c *config.C) string { } return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } - -func startDns(l *logrus.Logger, c *config.C) { - dnsAddr = getDnsServerAddr(c) - dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} - l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") - err := dnsServer.ListenAndServe() - defer dnsServer.Shutdown() - if err != nil { - l.Errorf("Failed to start server: %s\n ", err.Error()) - } -} - -func reloadDns(l *logrus.Logger, c *config.C) { - if dnsAddr == getDnsServerAddr(c) { - l.Debug("No DNS server config change detected") - return - } - - l.Debug("Restarting DNS server") - dnsServer.Shutdown() - go startDns(l, c) -} diff --git a/dns_server_test.go b/dns_server_test.go index 356e5890..dcea046c 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,19 +1,43 @@ package nebula import ( + "context" + "log/slog" + "net" "net/netip" + "strconv" "testing" + "time" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +type stubDNSWriter struct{} + +func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (stubDNSWriter) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353} +} +func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil } +func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil } +func (stubDNSWriter) Close() error { return nil } +func (stubDNSWriter) TsigStatus() error { return nil } +func (stubDNSWriter) TsigTimersOnly(bool) {} +func (stubDNSWriter) Hijack() {} + func TestParsequery(t *testing.T) { - l := logrus.New() + l := slog.New(slog.DiscardHandler) hostMap := &HostMap{} - ds := newDnsRecords(l, &CertState{}, hostMap) + ds := &dnsServer{ + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + } + 
ds.enabled.Store(true) addrs := []netip.Addr{ netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5"), @@ -21,18 +45,56 @@ func TestParsequery(t *testing.T) { netip.MustParseAddr("fd01::25"), } ds.Add("test.com.com", addrs) + ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")}) + ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")}) m := &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) m = &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeAAAA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + // A known name with no record of the requested type should return NODATA + // (NOERROR with empty answer), not NXDOMAIN. + m = &dns.Msg{} + m.SetQuestion("v4only.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + m = &dns.Msg{} + m.SetQuestion("v6only.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + // An unknown name should still return NXDOMAIN. 
+ m = &dns.Msg{} + m.SetQuestion("unknown.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) + + // short lookups should not fail + m = &dns.Msg{} + m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}} + ds.parseQuery(m, stubDNSWriter{}) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) + + m = &dns.Msg{} + m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}} + ds.parseQuery(m, stubDNSWriter{}) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) } func Test_getDnsServerAddr(t *testing.T) { @@ -71,3 +133,208 @@ func Test_getDnsServerAddr(t *testing.T) { } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) } + +func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { + t.Helper() + sl := slog.New(slog.DiscardHandler) + ds := &dnsServer{ + l: sl, + ctx: context.Background(), + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: &HostMap{}, + } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", ds.handleDnsRequest) + return ds, config.NewC(nil) +} + +func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) { + c.Settings["lighthouse"] = map[string]any{ + "am_lighthouse": amLighthouse, + "serve_dns": serveDns, + "dns": map[string]any{ + "host": host, + "port": port, + }, + } +} + +func TestDnsServer_reload_initial_disabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, false) + + require.NoError(t, ds.reload(c, true)) + assert.False(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + assert.Nil(t, ds.server) +} + +func TestDnsServer_reload_initial_enabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, true) + + require.NoError(t, ds.reload(c, true)) + assert.True(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + // 
initial never starts a runner; that's Control.Start's job + assert.Nil(t, ds.server) +} + +func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", false, true) + + require.NoError(t, ds.reload(c, true)) + // Wants DNS but isn't a lighthouse: gated off, no runner. + assert.False(t, ds.enabled.Load()) +} + +func TestDnsServer_reload_sameAddr_noOp(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, true) + + require.NoError(t, ds.reload(c, true)) + // No server running yet, no addr change. Reload should not spawn anything. + require.NoError(t, ds.reload(c, false)) + assert.True(t, ds.enabled.Load()) + assert.Nil(t, ds.server) +} + +func TestDnsServer_StartStop_lifecycle(t *testing.T) { + // Bind to a real (random) UDP port so we exercise the actual + // ListenAndServe + Shutdown plumbing including the started-chan race fix. + port := freeUDPPort(t) + + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", port, true, true) + require.NoError(t, ds.reload(c, true)) + + done := make(chan struct{}) + go func() { + ds.Start() + close(done) + }() + + waitFor(t, func() bool { + ds.serverMu.Lock() + started := ds.started + ds.serverMu.Unlock() + if started == nil { + return false + } + select { + case <-started: + return true + default: + return false + } + }) + + ds.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } +} + +func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) { + // Stop called immediately after Start should not deadlock even if bind + // hasn't completed yet. This exercises the started-chan close-on-bind-fail + // path: by binding to an obviously bad port (privileged) we get a fast + // bind error before NotifyStartedFunc fires. 
+ ds, c := newTestDnsServer(t) + // Use a port that should fail to bind (negative would be invalid, use a + // host that won't resolve to ensure listenUDP fails quickly). + setDnsConfig(c, "256.256.256.256", "53", true, true) + require.NoError(t, ds.reload(c, true)) + + done := make(chan struct{}) + go func() { + ds.Start() + close(done) + }() + + // Give Start a moment to attempt the bind and fail. + select { + case <-done: + // Bind failed and Start returned; Stop should be a no-op. + case <-time.After(time.Second): + t.Fatal("Start did not return after a bad bind") + } + + stopped := make(chan struct{}) + go func() { + ds.Stop() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("Stop hung after a failed bind") + } +} + +func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { + port := freeUDPPort(t) + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", port, true, true) + require.NoError(t, ds.reload(c, true)) + + startReturned := make(chan struct{}) + go func() { + ds.Start() + close(startReturned) + }() + waitForBind(t, ds) + + // Toggle serve_dns off; reload should shut the running server down. 
+ setDnsConfig(c, "127.0.0.1", port, true, false) + require.NoError(t, ds.reload(c, false)) + select { + case <-startReturned: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after reload disabled DNS") + } + assert.False(t, ds.enabled.Load()) +} + +func freeUDPPort(t *testing.T) string { + t.Helper() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + port := conn.LocalAddr().(*net.UDPAddr).Port + require.NoError(t, conn.Close()) + return strconv.Itoa(port) +} + +func waitForBind(t *testing.T, ds *dnsServer) { + t.Helper() + waitFor(t, func() bool { + ds.serverMu.Lock() + started := ds.started + ds.serverMu.Unlock() + if started == nil { + return false + } + select { + case <-started: + return true + default: + return false + } + }) +} + +func waitFor(t *testing.T, cond func() bool) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatal("timed out waiting for condition") +} diff --git a/e2e/handshake_manager_test.go b/e2e/handshake_manager_test.go new file mode 100644 index 00000000..b06564d1 --- /dev/null +++ b/e2e/handshake_manager_test.go @@ -0,0 +1,577 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" +) + +// makeHandshakePacket creates a handshake packet with the given parameters. 
+func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, remoteIndex uint32, counter uint64) *udp.Packet { + data := make([]byte, 200) + header.Encode(data, header.Version, header.Handshake, subtype, remoteIndex, counter) + for i := header.Len; i < len(data); i++ { + data[i] = byte(i) + } + return &udp.Packet{To: to, From: from, Data: data} +} + +func TestHandshakeRetransmitDuplicate(t *testing.T) { + t.Parallel() + // Verify the responder correctly handles receiving the same msg1 multiple times + // (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen + // and the cached response is resent. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake from me to them") + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) + + t.Log("Grab my msg1") + msg1 := myControl.GetFromUDP(true) + + t.Log("Inject msg1 into them, first time") + theirControl.InjectUDPPacket(msg1) + _ = theirControl.GetFromUDP(true) + + t.Log("Inject the SAME msg1 again, tests ErrAlreadySeen path") + theirControl.InjectUDPPacket(msg1) + resp2 := theirControl.GetFromUDP(true) + assert.NotNil(t, resp2, "should get cached response on duplicate msg1") + + t.Log("Complete handshake with cached response") + myControl.InjectUDPPacket(resp2) + myControl.WaitForType(1, 0, theirControl) + + 
t.Log("Drain cached packet and verify tunnel works") + cachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify only one tunnel exists on each side") + assert.Len(t, myControl.ListHostmapHosts(false), 1) + assert.Len(t, theirControl.ListHostmapHosts(false), 1) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeTruncatedPacketRecovery(t *testing.T) { + t.Parallel() + // Verify that a truncated handshake packet is ignored and the real + // packet can still complete the handshake. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake") + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) + + t.Log("Get msg1 and deliver to responder") + msg1 := myControl.GetFromUDP(true) + theirControl.InjectUDPPacket(msg1) + + t.Log("Get the real response") + realResp := theirControl.GetFromUDP(true) + + t.Log("Truncate the response and inject, should be ignored") + truncResp := realResp.Copy() + truncResp.Data = truncResp.Data[:header.Len] + myControl.InjectUDPPacket(truncResp) + + t.Log("Verify pending handshake survived the truncated packet") + assert.NotEmpty(t, myControl.ListHostmapHosts(true), 
"pending handshake should still exist") + + t.Log("Inject real response, should complete handshake") + myControl.InjectUDPPacket(realResp) + myControl.WaitForType(1, 0, theirControl) + + t.Log("Drain and verify tunnel") + cachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { + t.Parallel() + // A msg2 arriving with no matching pending index should be silently dropped + // with no response sent and no state changes. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Complete a normal handshake") + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) + r.RouteForAllUntilTxTun(theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Record hostmap state") + myIndexes := len(myControl.ListHostmapIndexes(false)) + + t.Log("Inject a fake msg2 with unknown RemoteIndex") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0xDEADBEEF, 2)) + + t.Log("Verify no new indexes created") + assert.Equal(t, myIndexes, 
len(myControl.ListHostmapIndexes(false))) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false), "should not send a response to orphaned msg2") + + t.Log("Verify existing tunnel still works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeUnknownMessageCounter(t *testing.T) { + t.Parallel() + // A handshake packet with an unexpected message counter should be silently + // dropped with no side effects and no UDP response. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + myControl.Start() + theirControl.Start() + + t.Log("Inject handshake with MessageCounter=3") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 3)) + + t.Log("Inject handshake with MessageCounter=99") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 99)) + + t.Log("Verify no tunnels or pending handshakes") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeUnknownSubtype(t *testing.T) { + t.Parallel() + // A handshake packet with an unknown subtype should be silently dropped. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.Start() + theirControl.Start() + + t.Log("Inject handshake with unknown subtype 99") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.MessageSubType(99), 0, 1)) + + t.Log("Verify no tunnels or pending handshakes") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeLateResponse(t *testing.T) { + t.Parallel() + // After a handshake times out, a late response should be silently ignored + // with no new tunnels created. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{ + "handshakes": m{ + "try_interval": "200ms", + "retries": 2, + }, + }) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + myControl.Start() + theirControl.Start() + + t.Log("Trigger handshake from me") + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) + + t.Log("Grab msg1 but don't deliver") + msg1 := myControl.GetFromUDP(true) + + t.Log("Wait for handshake to time out") + for i := 0; i < 5; i++ { + time.Sleep(300 * time.Millisecond) + myControl.GetFromUDP(false) + } + + t.Log("Confirm no pending handshakes remain") + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Deliver old msg1 to them, they create a tunnel") + theirControl.InjectUDPPacket(msg1) + resp := theirControl.GetFromUDP(true) + assert.NotNil(t, resp) + + t.Log("Inject late response into me, should be ignored") + myControl.InjectUDPPacket(resp) + + t.Log("No tunnel should exist on my side") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeSelfConnectionRejected(t *testing.T) { + t.Parallel() + // Verify that a node rejects a handshake containing its own VPN IP in the + // peer cert. We do this by sending the initiator's own msg1 back to itself. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + + // Need a lighthouse entry to trigger a handshake + myControl.InjectLightHouseAddr(netip.MustParseAddr("10.128.0.2"), netip.MustParseAddrPort("10.0.0.2:4242")) + + myControl.Start() + + t.Log("Trigger handshake from me") + myControl.InjectTunPacket(BuildTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) + msg1 := myControl.GetFromUDP(true) + + t.Log("Drain any handshake retransmits before injecting") + time.Sleep(100 * time.Millisecond) + for myControl.GetFromUDP(false) != nil { + } + + t.Log("Feed my own msg1 back to me as if it came from someone else") + selfMsg := msg1.Copy() + selfMsg.From = netip.MustParseAddrPort("10.0.0.99:4242") + selfMsg.To = myUdpAddr + myControl.InjectUDPPacket(selfMsg) + + t.Log("Verify no response was sent (self-connection rejected)") + time.Sleep(100 * time.Millisecond) + // Drain any further retransmits from the original handshake, then check + // that none of them are a handshake response (MessageCounter=2) + h := &header.H{} + for { + p := myControl.GetFromUDP(false) + if p == nil { + break + } + _ = h.Parse(p.Data) + assert.NotEqual(t, uint64(2), h.MessageCounter, + "should not send a stage 2 response to self-connection") + } + + t.Log("Verify no tunnel to myself was created") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)) + + myControl.Stop() +} + +func TestHandshakeMessageCounter0Dropped(t *testing.T) { + t.Parallel() + // MessageCounter=0 is not a valid handshake message and should be dropped. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + _, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.Start() + + t.Log("Inject handshake with MessageCounter=0") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 0)) + + time.Sleep(100 * time.Millisecond) + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() +} + +func TestHandshakeRemoteAllowList(t *testing.T) { + t.Parallel() + // Verify that a handshake from a blocked underlay IP is dropped with no + // response and no state changes. Then verify the same packet from an + // allowed IP succeeds. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{ + "lighthouse": m{ + "remote_allow_list": m{ + "10.0.0.0/8": true, + "0.0.0.0/0": false, + }, + }, + }) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake from them") + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi"))) + msg1 := theirControl.GetFromUDP(true) + + t.Log("Rewrite the source to a blocked IP and inject") + 
blockedMsg := msg1.Copy() + blockedMsg.From = netip.MustParseAddrPort("192.168.1.1:4242") + myControl.InjectUDPPacket(blockedMsg) + + t.Log("Verify no tunnel, no pending, no response from blocked source") + time.Sleep(100 * time.Millisecond) + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + assert.Nil(t, myControl.GetFromUDP(false), "should not respond to blocked source") + + t.Log("Now inject the real packet from the allowed source") + myControl.InjectUDPPacket(msg1) + + t.Log("Verify handshake completes from allowed source") + resp := myControl.GetFromUDP(true) + assert.NotNil(t, resp) + theirControl.InjectUDPPacket(resp) + theirControl.WaitForType(1, 0, myControl) + + t.Log("Drain cached packet and verify tunnel works") + cachedPacket := myControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { + t.Parallel() + // When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel + // remains functional and hostmap index count is stable. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Complete a normal handshake via the router") + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) + r.RouteForAllUntilTxTun(theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Record hostmap state") + theirIndexes := len(theirControl.ListHostmapIndexes(false)) + hi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, hi) + originalRemote := hi.CurrentRemote + + t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)") + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam"))) + r.RouteForAllUntilTxTun(theirControl) + + t.Log("Verify tunnel still works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify remote is still valid and index count is stable") + hi2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, hi2) + assert.Equal(t, originalRemote, hi2.CurrentRemote) + assert.Equal(t, theirIndexes, len(theirControl.ListHostmapIndexes(false)), + "no extra indexes should be created from ErrAlreadySeen") + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeWrongResponderPacketStore(t 
*testing.T) { + t.Parallel() + // Verify that when the wrong host responds, the cached packets are + // transferred to the new handshake, the evil tunnel is closed, evil's + // address is blocked, and the correct tunnel is eventually established. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIpNet, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) + + r := router.NewR(t, myControl, theirControl, evilControl) + defer r.RenderFlow() + + myControl.Start() + theirControl.Start() + evilControl.Start() + + t.Log("Send multiple packets to them (cached during handshake)") + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1"))) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2"))) + + t.Log("Route until evil tunnel is closed") + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + if err := h.Parse(p.Data); err != nil { + panic(err) + } + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + return router.KeepRouting + }) + + t.Log("Verify evil's address is blocked in the new pending handshake") + pendingHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) + if pendingHI != nil { + assert.NotContains(t, pendingHI.RemoteAddrs, evilUdpAddr, + "evil's address should be blocked") + } + + t.Log("Inject correct lighthouse addr for them") + 
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + t.Log("Route until cached packets arrive at the real them") + p := r.RouteForAllUntilTxTun(theirControl) + assert.NotNil(t, p, "a cached packet should be delivered to the correct host") + + t.Log("Verify the correct host has a tunnel") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + + t.Log("Verify no hostinfo artifacts from evil remain") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), true), + "no pending hostinfo for evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), false), + "no main hostinfo for evil") + + myControl.Stop() + theirControl.Stop() + evilControl.Stop() +} + +func TestHandshakeRelayComplete(t *testing.T) { + t.Parallel() + // Verify that a relay handshake completes correctly and relay state is + // properly maintained on all three nodes. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger handshake via relay") + 
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay"))) + + p := r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Verify bidirectional tunnel via relay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify relay state on my side shows relay-to-me") + myHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + assert.NotNil(t, myHI) + assert.NotEmpty(t, myHI.CurrentRelaysToMe, "should have relay-to-me for them") + + t.Log("Verify relay state on their side shows relay-to-me") + theirHI := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, theirHI) + assert.NotEmpty(t, theirHI.CurrentRelaysToMe, "should have relay-to-me for me") + + t.Log("Verify relay node shows through-me relays") + relayHI := relayControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, relayHI) + + myControl.Stop() + relayControl.Stop() + theirControl.Stop() +} + +// NOTE: Relay V1 cert + IPv6 rejection is not tested here because +// BuildTunUDPPacket from a V4 node to a V6 address panics in the test +// framework. The check is in handshake_manager.go handleOutbound relay +// logic (lines ~304-313): if the relay host has a V1 cert and either +// address is IPv6, the relay is skipped. + +// NOTE: Relay reestablishment (Disestablished state transition) is covered +// by the existing TestReestablishRelays in handshakes_test.go. 
diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 67b166b1..d0b9543c 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -11,12 +11,12 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -40,11 +40,22 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + // Pre-build the IP packet bytes once so the bench measures the data plane, + // not gopacket SerializeLayers overhead. + prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + // EnableFanIn switches the router to a 0-alloc routing path. Required + // for hot-path benchmarks; would conflict with GetFromUDP-using tests. + r.EnableFanIn() + b.ResetTimer() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - _ = r.RouteForAllUntilTxTun(theirControl) + myControl.InjectTunPacket(prebuilt) + // Release the TUN-side bytes back to the harness freelist; the bench + // just confirms a packet arrived, the contents aren't inspected. 
+ overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl)) } myControl.Stop() @@ -72,11 +83,15 @@ func BenchmarkHotPathRelay(b *testing.B) { theirControl.Start() assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + + prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + r.EnableFanIn() + b.ResetTimer() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - _ = r.RouteForAllUntilTxTun(theirControl) + myControl.InjectTunPacket(prebuilt) + overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl)) } myControl.Stop() @@ -85,6 +100,7 @@ func BenchmarkHotPathRelay(b *testing.B) { } func TestGoodHandshake(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -97,7 +113,7 @@ func TestGoodHandshake(t *testing.T) { theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) t.Log("Have them consume my stage 0 packet. 
They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -135,6 +151,7 @@ func TestGoodHandshake(t *testing.T) { } func TestGoodHandshakeNoOverlap(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack! @@ -170,6 +187,7 @@ func TestGoodHandshakeNoOverlap(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) @@ -189,7 +207,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -246,6 +264,7 @@ func TestWrongResponderHandshake(t *testing.T) { } func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) @@ -270,7 +289,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { 
evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -328,6 +347,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { } func TestStage1Race(t *testing.T) { + t.Parallel() // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel @@ -348,8 +368,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -408,6 +428,7 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -425,7 +446,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) 
{ theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -436,7 +457,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))) p = r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -457,6 +478,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -474,7 +496,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, 
myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -486,7 +508,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) @@ -508,6 +530,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -528,7 +551,7 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -537,6 +560,7 @@ func TestRelays(t *testing.T) { } func TestRelaysDontCareAboutIps(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := 
newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}}) @@ -557,7 +581,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -566,6 +590,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) { } func TestReestablishRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -586,14 +611,14 @@ func TestReestablishRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Ensure packet traversal from them to me via the relay") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + 
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") @@ -608,7 +633,7 @@ func TestReestablishRelays(t *testing.T) { for curIndexes >= start { curIndexes = len(myControl.GetHostmap().Indexes) r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { return router.RouteAndExit @@ -625,7 +650,7 @@ func TestReestablishRelays(t *testing.T) { myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p = r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -660,7 +685,7 @@ func TestReestablishRelays(t *testing.T) { t.Log("Assert the tunnel works the other way, too") for { t.Log("RouteForAllUntilTxTun") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") @@ -697,6 +722,7 @@ func TestReestablishRelays(t *testing.T) { } func TestStage1RaceRelays(t *testing.T) { + t.Parallel() //NOTE: 
this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) @@ -729,8 +755,8 @@ func TestStage1RaceRelays(t *testing.T) { assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -744,12 +770,12 @@ func TestStage1RaceRelays(t *testing.T) { } func TestStage1RaceRelays2(t *testing.T) { + t.Parallel() //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) - l := 
NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -771,49 +797,41 @@ func TestStage1RaceRelays2(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - l.Info("Get a tunnel between me and relay") assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - l.Info("Get a tunnel between them and relay") assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) - r.Log("Wait for a packet from them to me") - l.Info("Wait for a packet from them to me; myControl") + r.Log("Wait for a packet from them to me; myControl") r.RouteForAllUntilTxTun(myControl) - l.Info("Wait for a packet from them to me; theirControl") + r.Log("Wait for a packet from them to me; theirControl") r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") - l.Info("Wait until we remove extra 
tunnels") - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) retries := 60 for hostInfos > 6 && retries > 0 { hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... 
myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) @@ -821,7 +839,6 @@ func TestStage1RaceRelays2(t *testing.T) { } r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() @@ -830,6 +847,7 @@ func TestStage1RaceRelays2(t *testing.T) { } func TestRehandshakingRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -850,7 +868,7 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -933,6 +951,7 @@ func TestRehandshakingRelays(t *testing.T) { } func TestRehandshakingRelaysPrimary(t *testing.T) { + t.Parallel() // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, 
myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) @@ -954,7 +973,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -1037,6 +1056,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) @@ -1132,6 +1152,7 @@ func TestRehandshaking(t *testing.T) { } func TestRehandshakingLoser(t *testing.T) { + t.Parallel() // The purpose of this test is that the race loser renews their certificate and rehandshakes. 
The final tunnel // Should be the one with the new certificate ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -1230,6 +1251,7 @@ func TestRehandshakingLoser(t *testing.T) { } func TestRaceRegression(t *testing.T) { + t.Parallel() // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo @@ -1253,8 +1275,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -1290,6 +1312,7 @@ func TestRaceRegression(t *testing.T) { } func TestV2NonPrimaryWithLighthouse(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) @@ -1330,6 +1353,7 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) { } func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) 
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}}) @@ -1369,7 +1393,84 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { theirControl.Stop() } +func TestLighthouseUpdateOnReload(t *testing.T) { + t.Parallel() + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + // Create the lighthouse + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh", "10.128.0.1/24", m{"lighthouse": m{"am_lighthouse": true}}) + + // Create a client with NO lighthouse configured and a long update interval. + // The initial SendUpdate at startup will be a no-op since no lighthouses are known. + myControl, myVpnIpNet, _, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.2/24", m{ + "lighthouse": m{ + "interval": 600, + "local_allow_list": m{ + "10.0.0.0/24": true, + "::/0": false, + }, + }, + }) + + r := router.NewR(t, lhControl, myControl) + defer r.RenderFlow() + + lhControl.Start() + myControl.Start() + + // Drain any startup packets (there should be none meaningful) + r.FlushAll() + + // Verify lighthouse has no knowledge of the client + assert.Nil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr())) + + // Build a new config that adds the lighthouse + newSettings := make(m) + for k, v := range myConfig.Settings { + newSettings[k] = v + } + newSettings["static_host_map"] = m{ + lhVpnIpNet[0].Addr().String(): []any{lhUdpAddr.String()}, + } + newSettings["lighthouse"] = m{ + "hosts": []any{lhVpnIpNet[0].Addr().String()}, + "interval": 600, + "local_allow_list": m{ + "10.0.0.0/24": true, + "::/0": false, + }, + } + newCfg, err := yaml.Marshal(newSettings) + require.NoError(t, err) + + // Reload the config. The lighthouse.hosts change triggers TriggerUpdate, + // which wakes the update worker. 
It calls SendUpdate, initiating a + // handshake to the new lighthouse and caching the HostUpdateNotification. + require.NoError(t, myConfig.ReloadConfigString(string(newCfg))) + + // Route until the lighthouse receives the HostUpdateNotification. + // This covers: handshake stage 1, stage 2, then the cached update. + done := make(chan struct{}) + go func() { + r.RouteForAllUntilAfterMsgTypeTo(lhControl, header.LightHouse, 0) + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for lighthouse update after config reload") + } + + // Verify lighthouse now has the client's addresses + assert.NotNil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr())) + + r.RenderHostmaps("Final hostmaps", lhControl, myControl) + lhControl.Stop() + myControl.Stop() +} + func TestGoodHandshakeUnsafeDest(t *testing.T) { + t.Parallel() unsafePrefix := "192.168.6.0/24" ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil) @@ -1391,7 +1492,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) { theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) t.Log("Have them consume my stage 0 packet. 
They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -1419,7 +1520,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) { assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80) //reply - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))) //wait for reply theirControl.WaitForType(1, 0, myControl) theirCachedPacket := myControl.GetFromTun(true) diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 39843efe..b555fbc4 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,7 +4,6 @@ package e2e import ( - "fmt" "io" "net/netip" "os" @@ -12,15 +11,18 @@ import ( "testing" "time" + "log/slog" + "dario.cat/mergo" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.yaml.in/yaml/v3" @@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), - "level": l.Level.String(), + "level": testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), - "level": l.Level.String(), + "level": testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -294,12 +294,12 @@ func deadline(t *testing.T, seconds 
time.Duration) doneCb { func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) + controlB.InjectTunPacket(BuildTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) + controlA.InjectTunPacket(BuildTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } @@ -379,24 +379,87 @@ func getAddrs(ns []netip.Prefix) []netip.Addr { return a } -func NewTestLogger() *logrus.Logger { - l := logrus.New() - +func NewTestLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - l.SetLevel(logrus.PanicLevel) - return l + return slog.New(slog.NewTextHandler(io.Discard, nil)) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace + } + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// testLogLevelName returns the level name string accepted by logging.ApplyConfig +// for the current TEST_LOGS setting. Kept in sync with NewTestLogger. +func testLogLevelName() string { + switch os.Getenv("TEST_LOGS") { + case "2": + return "debug" + case "3": + return "trace" + case "": + return "info" + } + return "info" +} + +// BuildTunUDPPacket assembles an IP+UDP packet suitable for Control.InjectTunPacket. +// Using UDP here because it's a simpler protocol. 
+func BuildTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) []byte { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip } - return l + udp := layers.UDP{ + SrcPort: layers.UDPPort(fromPort), + DstPort: layers.UDPPort(toPort), + } + if err := udp.SetNetworkLayerForChecksum(netLayer); err != nil { + panic(err) + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + serialize = append(serialize, &udp, gopacket.Payload(data)) + if err := gopacket.SerializeLayers(buffer, opt, serialize...); err != nil { + panic(err) + } + + return buffer.Bytes() } diff --git a/e2e/leak_test.go b/e2e/leak_test.go new file mode 100644 index 00000000..ffb024fe --- /dev/null +++ b/e2e/leak_test.go @@ -0,0 +1,51 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "go.uber.org/goleak" +) + +// TestNoGoroutineLeaks brings up two nebula instances, completes a tunnel, +// stops both, and asserts no goroutines leak past the shutdown. goleak's +// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain +// before failing the assertion. 
+// +// IgnoreCurrent is necessary in the parallelized suite: other tests can +// leave goroutines mid-shutdown when this one runs (Stop is async, the +// wg.Wait() drain is not blocking on test return). We're checking that +// *this* test's setup tears down cleanly, not that the whole suite is +// idle at this moment. Intentionally NOT t.Parallel()'d for the same +// reason — concurrent test goroutines would always show up. +func TestNoGoroutineLeaks(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() + r.RenderFlow() + + // Settle period: Stop() is non-blocking; the wg-driven goroutines need + // a moment to drain. goleak retries internally too, but a short explicit + // settle reduces flakes when the suite is busy. + time.Sleep(50 * time.Millisecond) +} diff --git a/e2e/router/router.go b/e2e/router/router.go index c8264ab7..72012073 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -13,6 +13,7 @@ import ( "regexp" "sort" "sync" + "sync/atomic" "testing" "time" @@ -24,6 +25,19 @@ import ( "golang.org/x/exp/maps" ) +// outNatKey is the (from, to) pair used by outNat. Comparable struct, so it works as a map key without the +// allocation cost of a string-concat key. 
+type outNatKey struct { + from, to netip.AddrPort +} + +// fannedPacket pairs a UDP TX packet with its source control so the router can route it after popping from +// the fan-in channel. +type fannedPacket struct { + from *nebula.Control + pkt *udp.Packet +} + type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? @@ -34,12 +48,28 @@ type R struct { // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender - // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]netip.AddrPort + outNat map[outNatKey]netip.AddrPort // A map of vpn ip to the nebula control it belongs to vpnControls map[netip.Addr]*nebula.Control + // Cached select infrastructure for RouteForAllUntilTxTun. + // The controls map is immutable after NewR so the cases are good for the test lifetime. + // We only rebuild if a different receiver is asked. + selRecvCtl *nebula.Control + selCases []reflect.SelectCase + selCtls []*nebula.Control + + // Optional fan-in mode for hot-path benchmarks: one forwarder goroutine per control drains UDP TX into udpFanIn, + // so RouteForAllUntilTxTun can do a fixed 2-way native select instead of paying reflect.Select per call. + // Off by default (would otherwise interleave with tests that use GetFromUDP directly on the same control). + // Enabled by EnableFanIn. 
+ udpFanIn chan fannedPacket + stopFanIn chan struct{} + fanInWG sync.WaitGroup + fanInMu sync.Mutex + fanInOn atomic.Bool + ignoreFlows []ignoreFlow flow []flowEntry @@ -119,7 +149,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { controls: make(map[netip.AddrPort]*nebula.Control), vpnControls: make(map[netip.Addr]*nebula.Control), inNat: make(map[netip.AddrPort]*nebula.Control), - outNat: make(map[string]netip.AddrPort), + outNat: make(map[outNatKey]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -153,8 +183,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { case <-ctx.Done(): return case <-clockSource.C: + r.Lock() r.renderHostmaps("clock tick") r.renderFlow() + r.Unlock() } } }() @@ -180,15 +212,21 @@ func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { // RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening. func (r *R) RenderFlow() { r.cancelRender() + r.Lock() + defer r.Unlock() r.renderFlow() } // CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected func (r *R) CancelFlowLogs() { r.cancelRender() + r.Lock() r.flow = nil + r.Unlock() } +// renderFlow writes the flow log to disk. Caller must hold r.Lock. renderFlow reads r.flow / r.additionalGraphs and +// the *packet pointers stashed inside, all of which are mutated under the same lock by routing paths. 
func (r *R) renderFlow() { if r.flow == nil { return @@ -434,68 +472,157 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) - c.InjectUDPPacket(p) + c.InjectUDPPacket(p) // copies internally; original is ours to release fp.WasReceived() r.Unlock() + p.Release() } } } -// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on receivers tun -// If the router doesn't have the nebula controller for that address, we panic +// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on the receiver's tun. +// If a control's UDP TX address can't be matched to a registered control, we panic. +// +// For allocation-sensitive callers (hot-path benchmarks, in particular relay +// benches with 3+ controls), call EnableFanIn() first. func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { + if r.fanInOn.Load() { + return r.routeFanIn(receiver) + } + return r.routeReflect(receiver) +} + +// routeFanIn is the alloc-free path used when EnableFanIn is in effect. +func (r *R) routeFanIn(receiver *nebula.Control) []byte { + tunTx := receiver.GetTunTxChan() + for { + select { + case p := <-tunTx: + r.Lock() + if r.flow != nil { + np := udp.Packet{Data: make([]byte, len(p))} + copy(np.Data, p) + r.unlockedInjectFlow(receiver, receiver, &np, true) + } + r.Unlock() + return p + case fp := <-r.udpFanIn: + r.routeUDP(fp.from, fp.pkt) + } + } +} + +// routeReflect is the default reflect.Select-based path. Pays the boxing allocation per call but doesn't interfere +// with tests that pull packets directly from controls' UDP TX channels via GetFromUDP. 
+func (r *R) routeReflect(receiver *nebula.Control) []byte { + sc, cm := r.selectCasesFor(receiver) + for { + x, rx, _ := reflect.Select(sc) + if x == 0 { + p := rx.Interface().([]byte) + r.Lock() + if r.flow != nil { + np := udp.Packet{Data: make([]byte, len(p))} + copy(np.Data, p) + r.unlockedInjectFlow(cm[x], cm[x], &np, true) + } + r.Unlock() + return p + } + r.routeUDP(cm[x], rx.Interface().(*udp.Packet)) + } +} + +// EnableFanIn switches RouteForAllUntilTxTun to the alloc-free fan-in path. +// One forwarder goroutine per registered control drains UDP TX into a shared channel that RouteForAllUntilTxTun selects +// on alongside the receiver's TUN TX channel. +func (r *R) EnableFanIn() { + r.fanInMu.Lock() + defer r.fanInMu.Unlock() + if r.fanInOn.Load() { + return + } + r.udpFanIn = make(chan fannedPacket, 32) + r.stopFanIn = make(chan struct{}) + for _, c := range r.controls { + r.startFanInWorker(c) + } + r.fanInOn.Store(true) + r.t.Cleanup(r.stopFanInWorkers) +} + +// startFanInWorker spawns a goroutine that drains c's UDP TX into r.udpFanIn. +func (r *R) startFanInWorker(c *nebula.Control) { + r.fanInWG.Add(1) + udpTx := c.GetUDPTxChan() + go func() { + defer r.fanInWG.Done() + for { + select { + case <-r.stopFanIn: + return + case p := <-udpTx: + select { + case <-r.stopFanIn: + p.Release() + return + case r.udpFanIn <- fannedPacket{from: c, pkt: p}: + } + } + } + }() +} + +// stopFanInWorkers signals the fan-in goroutines to exit and waits for them. +func (r *R) stopFanInWorkers() { + r.fanInMu.Lock() + wasOn := r.fanInOn.Swap(false) + r.fanInMu.Unlock() + if !wasOn { + return + } + close(r.stopFanIn) + r.fanInWG.Wait() +} + +// routeUDP forwards a UDP TX packet from the named source control to the destination control derived from p.To, +// releasing the source packet after InjectUDPPacket has copied its bytes into a fresh pool slot. 
+func (r *R) routeUDP(from *nebula.Control, p *udp.Packet) { + r.Lock() + defer r.Unlock() + a := from.GetUDPAddr() + c := r.getControl(a, p.To, p) + if c == nil { + panic(fmt.Sprintf("No control for udp tx %s", p.To)) + } + fp := r.unlockedInjectFlow(from, c, p, false) + c.InjectUDPPacket(p) // copies internally; original is ours to release + fp.WasReceived() + p.Release() +} + +// selectCasesFor returns the SelectCase array used by routeReflect: one slot for the receiver's TUN TX channel followed +// by one per control's UDP TX channel. Cached for the test lifetime, only rebuilt if the receiver changes. +func (r *R) selectCasesFor(receiver *nebula.Control) ([]reflect.SelectCase, []*nebula.Control) { + r.Lock() + defer r.Unlock() + if r.selRecvCtl == receiver && r.selCases != nil { + return r.selCases, r.selCtls + } sc := make([]reflect.SelectCase, len(r.controls)+1) cm := make([]*nebula.Control, len(r.controls)+1) - - i := 0 - sc[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(receiver.GetTunTxChan()), - Send: reflect.Value{}, - } - cm[i] = receiver - - i++ + sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan())} + cm[0] = receiver + i := 1 for _, c := range r.controls { - sc[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(c.GetUDPTxChan()), - Send: reflect.Value{}, - } - + sc[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan())} cm[i] = c i++ } - - for { - x, rx, _ := reflect.Select(sc) - r.Lock() - - if x == 0 { - // we are the tun tx, we can exit - p := rx.Interface().([]byte) - np := udp.Packet{Data: make([]byte, len(p))} - copy(np.Data, p) - - r.unlockedInjectFlow(cm[x], cm[x], &np, true) - r.Unlock() - return p - - } else { - // we are a udp tx, route and continue - p := rx.Interface().(*udp.Packet) - a := cm[x].GetUDPAddr() - c := r.getControl(a, p.To, p) - if c == nil { - r.Unlock() - panic(fmt.Sprintf("No control for 
udp tx %s", p.To)) - } - fp := r.unlockedInjectFlow(cm[x], c, p, false) - c.InjectUDPPacket(p) - fp.WasReceived() - } - r.Unlock() - } + r.selRecvCtl = receiver + r.selCases = sc + r.selCtls = cm + return sc, cm } // RouteExitFunc will call the whatDo func with each udp packet from sender. @@ -522,6 +649,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { switch e { case ExitNow: r.Unlock() + p.Release() return case RouteAndExit: @@ -529,6 +657,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() + p.Release() return case KeepRouting: @@ -541,6 +670,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { } r.Unlock() + p.Release() } } @@ -641,6 +771,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { switch e { case ExitNow: r.Unlock() + p.Release() return case RouteAndExit: @@ -648,6 +779,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() + p.Release() return case KeepRouting: @@ -659,6 +791,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) } r.Unlock() + p.Release() } } @@ -702,19 +835,20 @@ func (r *R) FlushAll() { } receiver.InjectUDPPacket(p) r.Unlock() + p.Release() } } // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + if newAddr, ok := r.outNat[outNatKey{from: fromAddr, to: toAddr}]; ok { p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr + r.outNat[outNatKey{from: c.GetUDPAddr(), to: fromAddr}] = toAddr return c } diff --git a/e2e/tunnels_test.go 
b/e2e/tunnels_test.go index e89cf869..697f25af 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -12,11 +12,14 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v3" ) func TestDropInactiveTunnels(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -61,6 +64,7 @@ func TestDropInactiveTunnels(t *testing.T) { } func TestCertUpgrade(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -155,6 +159,7 @@ func TestCertUpgrade(t *testing.T) { } func TestCertDowngrade(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -253,6 +258,7 @@ func TestCertDowngrade(t *testing.T) { } func TestCertMismatchCorrection(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -320,6 +326,7 @@ func TestCertMismatchCorrection(t *testing.T) { } func 
TestCrossStackRelaysWork(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) @@ -348,14 +355,14 @@ func TestCrossStackRelaysWork(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) t.Log("reply?") - theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) @@ -365,3 +372,107 @@ func TestCrossStackRelaysWork(t *testing.T) { //theirControl.Stop() //relayControl.Stop() } + +func TestCloseTunnelAuthenticated(t *testing.T) { + t.Parallel() + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", 
"10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + r.Log("Close the tunnel") + myControl.CloseTunnel(theirVpnIpNet[0].Addr(), false) + r.FlushAll() + + waitStart := time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 && theirIndexes == 0 { + break + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*6 { + t.Fatal("Tunnel should have been declared inactive after 2 seconds and before 6 seconds") + } + + time.Sleep(1 * time.Second) + //r.FlushAll() + } + + r.Logf("Happy path success, tunnels were dropped within %v", time.Since(waitStart)) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + r.Log("Assert another tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + hi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if hi == nil { + t.Fatal("There is no hostinfo for this tunnel") + } + myHi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if myHi == nil { + t.Fatal("There is no hostinfo for my tunnel") + } + r.Log("It does") + + buf := make([]byte, 1024) + hdr := header.H{ + Version: 1, + Type: header.CloseTunnel, + Subtype: 0, + Reserved: 0, + RemoteIndex: hi.RemoteIndex, + MessageCounter: 5, + 
} + out, err := hdr.Encode(buf) + if err != nil { + t.Fatal(err) + } + + pkt := &udp.Packet{ + To: hi.CurrentRemote, + From: myHi.CurrentRemote, + Data: out, + } + r.InjectUDPPacket(myControl, theirControl, pkt) + r.Log("Injected bogus close tunnel. Let's see!") + waitStart = time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 { + t.Fatal("myIndexes should not be 0") + } + if theirIndexes == 0 { + t.Fatal("theirIndexes should not be 0, they should have rejected this bogus packet") + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*4 { + t.Log("The tunnel would have been gone by now") + break + } + + time.Sleep(1 * time.Second) + r.FlushAll() + } + + myControl.Stop() + theirControl.Stop() +} diff --git a/examples/config.yml b/examples/config.yml index b713be4c..963307d7 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -204,6 +204,12 @@ punchy: # Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access. #trusted_cas: #- "ssh public key string" + # sandbox_dir restricts file paths for profiling commands (start-cpu-profile, save-heap-profile, + # save-mutex-profile) to the specified directory. Relative paths will be resolved within this directory, + # and absolute paths outside of it will be rejected. Default is $TMP/nebula-debug. + # The directory is NOT automatically created. + # Overriding this to "" is the same as "/" and will allow overwriting any path on the host. + #sandbox_dir: /var/tmp/nebula-debug # EXPERIMENTAL: relay support for networks that can't establish direct connections. relay: @@ -327,24 +333,21 @@ tun: # Configure logging level logging: - # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. 
- #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some - # scenarios. Debug logging is also CPU intensive and will decrease performance overall. - # Only enable debug logging while actively investigating an issue. + # trace, debug, info, warn, or error. Default is info and is reloadable. + # fatal and panic are accepted for backwards compatibility and map to error. + #NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some + # scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall. + # Only enable debug or trace logging while actively investigating an issue. level: info - # json or text formats currently available. Default is text + # json or text formats currently available. Default is text. format: text - # Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false + # Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false. #disable_timestamp: true - # timestamp format is specified in Go time format, see: - # https://golang.org/pkg/time/#pkg-constants - # default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339) - # default when `format: text`: - # when TTY attached: seconds since beginning of execution - # otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339) - # As an example, to log as RFC3339 with millisecond precision, set to: - #timestamp_format: "2006-01-02T15:04:05.000Z07:00" + # Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable. +# The stats section is reloadable. A HUP may change the backend, toggle stats +# on or off, switch the listen/host address, or pick up new DNS for the +# configured graphite host. 
#stats: #type: graphite #prefix: nebula @@ -362,10 +365,12 @@ logging: # enables counter metrics for meta packets # e.g.: `messages.tx.handshake` # NOTE: `message.{tx,rx}.recv_error` is always emitted + # Not reloadable. #message_metrics: false # enables detailed counter metrics for lighthouse packets # e.g.: `lighthouse.rx.HostQuery` + # Not reloadable. #lighthouse_metrics: false # Handshake Manager Settings @@ -423,8 +428,8 @@ firewall: # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). - # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` + # a port specification is ignored if proto is `icmp` # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. 
Multiple values are AND'd together and a certificate would have to contain all groups to pass diff --git a/examples/go_service/main.go b/examples/go_service/main.go index 2f8efbfb..3f98fe3d 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -7,9 +7,9 @@ import ( "net" "os" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/service" ) @@ -64,8 +64,7 @@ pki: return err } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/firewall.go b/firewall.go index 45dc0691..adecbe81 100644 --- a/firewall.go +++ b/firewall.go @@ -1,11 +1,13 @@ package nebula import ( + "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "hash/fnv" + "log/slog" "net/netip" "reflect" "slices" @@ -16,7 +18,6 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -67,7 +68,7 @@ type Firewall struct { incomingMetrics firewallMetrics outgoingMetrics firewallMetrics - l *logrus.Logger + l *slog.Logger } type firewallMetrics struct { @@ -131,7 +132,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. 
-func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { +func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.InSendReject = false default: - l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`") + l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction) fw.InSendReject = false } @@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.OutSendReject = false default: - l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") + l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction) fw.OutSendReject = false } @@ -249,20 +250,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew // AddRule properly creates the in memory rule structure for a firewall table. func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { - // We need this rule string because we generate a hash. Removing this will break firewall reload. 
- ruleString := fmt.Sprintf( - "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, - ) - f.rules += ruleString + "\n" - - direction := "incoming" - if !incoming { - direction = "outgoing" - } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). - Info("Firewall rule added") - var ( ft *FirewallTable fp firewallPort @@ -280,6 +267,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort case firewall.ProtoUDP: fp = ft.UDP case firewall.ProtoICMP, firewall.ProtoICMPv6: + //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided + if startPort != firewall.PortAny { + f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort) + } + startPort = firewall.PortAny + endPort = firewall.PortAny fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -287,6 +280,21 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } + // We need this rule string because we generate a hash. Removing this will break firewall reload. 
+ ruleString := fmt.Sprintf( + "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", + incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, + ) + f.rules += ruleString + "\n" + + direction := "incoming" + if !incoming { + direction = "outgoing" + } + f.l.Info("Firewall rule added", + "firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}, + ) + return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } @@ -308,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } -func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { table = "firewall.inbound" @@ -349,24 +357,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw sPort = r.Port } - startPort, endPort, err := parsePort(sPort) - if err != nil { - return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) - } - var proto uint8 + var startPort, endPort int32 switch r.Proto { case "any": proto = firewall.ProtoAny + startPort, endPort, err = parsePort(sPort) case "tcp": proto = firewall.ProtoTCP + startPort, endPort, err = parsePort(sPort) case "udp": proto = firewall.ProtoUDP + startPort, endPort, err = parsePort(sPort) case "icmp": proto = firewall.ProtoICMP + startPort = firewall.PortAny + endPort = firewall.PortAny + if sPort != "" { + l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort) + } default: return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } + if err != nil { + 
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) + } if r.Cidr != "" && r.Cidr != "any" { _, err = netip.ParsePrefix(r.Cidr) @@ -383,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw } if warning := r.sanity(); warning != nil { - l.Warnf("%s rule #%v; %s", table, i, warning) + l.Warn("firewall rule sanity check", + "table", table, + "rule", i, + "warning", warning, + ) } err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) @@ -467,7 +486,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics { } } -// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new +// Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new // firewall object is created func (f *Firewall) Destroy() { //TODO: clean references if/when needed @@ -515,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, // We now know which firewall table to check against if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). - WithField("oldRulesVersion", c.rulesVersion). - Debugln("dropping old conntrack entry, does not match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } delete(conntrack.Conns, fp) conntrack.Unlock() return false } - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). 
- WithField("oldRulesVersion", c.rulesVersion). - Debugln("keeping old conntrack entry, does match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } c.rulesVersion = f.rulesVersion @@ -660,6 +679,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer return false } + // this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match + if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 { + // port numbers are re-used for connection tracking of ICMP, + // but we don't want to actually filter on them. + return fp[firewall.PortAny].match(p, c, caPool) + } + var port int32 if p.Fragment { @@ -804,10 +830,8 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { return true } - for _, group := range groups { - if group == "any" { - return true - } + if slices.Contains(groups, "any") { + return true } if host == "any" { @@ -917,7 +941,7 @@ type rule struct { CASha string } -func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { +func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) { r := rule{} m, ok := p.(map[string]any) @@ -948,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } - l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) + l.Warn("group was an array with a single value, converting to simple value", + "table", table, + "rule", i, + ) m["group"] = v[0] } @@ -1018,54 +1045,56 @@ func (r *rule) sanity() error { } } + if r.Code != "" { + return fmt.Errorf("code specified as [%s]. 
Support for 'code' will be dropped in a future release, as it has never been functional", r.Code) + } + //todo alert on cidr-any return nil } -func parsePort(s string) (startPort, endPort int32, err error) { +func parsePort(s string) (int32, int32, error) { + var err error + const notAPort int32 = -2 if s == "any" { - startPort = firewall.PortAny - endPort = firewall.PortAny - - } else if s == "fragment" { - startPort = firewall.PortFragment - endPort = firewall.PortFragment - - } else if strings.Contains(s, `-`) { - sPorts := strings.SplitN(s, `-`, 2) - sPorts[0] = strings.Trim(sPorts[0], " ") - sPorts[1] = strings.Trim(sPorts[1], " ") - - if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { - return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) - } - - rStartPort, err := strconv.Atoi(sPorts[0]) - if err != nil { - return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) - } - - rEndPort, err := strconv.Atoi(sPorts[1]) - if err != nil { - return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) - } - - startPort = int32(rStartPort) - endPort = int32(rEndPort) - - if startPort == firewall.PortAny { - endPort = firewall.PortAny - } - - } else { + return firewall.PortAny, firewall.PortAny, nil + } + if s == "fragment" { + return firewall.PortFragment, firewall.PortFragment, nil + } + if !strings.Contains(s, `-`) { rPort, err := strconv.Atoi(s) if err != nil { - return 0, 0, fmt.Errorf("was not a number; `%s`", s) + return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) } - startPort = int32(rPort) - endPort = startPort + return int32(rPort), int32(rPort), nil } - return + sPorts := strings.SplitN(s, `-`, 2) + for i := range sPorts { + sPorts[i] = strings.Trim(sPorts[i], " ") + } + if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { + return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) + } + + rStartPort, err := 
strconv.Atoi(sPorts[0]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) + } + + rEndPort, err := strconv.Atoi(sPorts[1]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) + } + + startPort := int32(rStartPort) + endPort := int32(rEndPort) + + if startPort == firewall.PortAny { + endPort = firewall.PortAny + } + + return startPort, endPort, nil } diff --git a/firewall/cache.go b/firewall/cache.go index 71b83f43..ba4b9732 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -1,10 +1,10 @@ package firewall import ( + "context" + "log/slog" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // ConntrackCache is used as a local routine cache to know if a given flow @@ -15,41 +15,49 @@ type ConntrackCacheTicker struct { cacheV uint64 cacheTick atomic.Uint64 + l *slog.Logger cache ConntrackCache } -func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { +func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker { if d == 0 { return nil } c := &ConntrackCacheTicker{ + l: l, cache: ConntrackCache{}, } - go c.tick(d) + go c.tick(ctx, d) return c } -func (c *ConntrackCacheTicker) tick(d time.Duration) { +func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) { + t := time.NewTicker(d) + defer t.Stop() for { - time.Sleep(d) - c.cacheTick.Add(1) + select { + case <-ctx.Done(): + return + case <-t.C: + c.cacheTick.Add(1) + } } } // Get checks if the cache ticker has moved to the next version before returning // the map. If it has moved, we reset the map. 
-func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { +func (c *ConntrackCacheTicker) Get() ConntrackCache { if c == nil { return nil } if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { - if l.Level == logrus.DebugLevel { - l.WithField("len", ll).Debug("resetting conntrack cache") + if c.l.Enabled(context.Background(), slog.LevelDebug) { + c.l.Debug("resetting conntrack cache", "len", ll) } c.cache = make(ConntrackCache, ll) } diff --git a/firewall/cache_test.go b/firewall/cache_test.go new file mode 100644 index 00000000..ab807984 --- /dev/null +++ b/firewall/cache_test.go @@ -0,0 +1,69 @@ +package firewall + +import ( + "bytes" + "log/slog" + "strings" + "testing" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +// The tests below pin the log format produced by ConntrackCacheTicker.Get +// so changes cannot silently break what operators are grepping for. The +// ticker's internal state (cache + cacheTick) is poked directly to avoid +// racing a goroutine-driven tick in tests. 
+ +func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker { + t.Helper() + c := &ConntrackCacheTicker{ + l: l, + cache: make(ConntrackCache, cacheLen), + } + for i := 0; i < cacheLen; i++ { + c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{} + } + c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path + return c +} + +func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 3) + c.Get() + + assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String()) +} + +func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 2) + c.Get() + + assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String())) +} + +func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo) + + c := newFixedTicker(t, l, 5) + c.Get() + + assert.Empty(t, buf.String()) +} + +func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 0) + c.Get() + + assert.Empty(t, buf.String()) +} diff --git a/firewall/packet.go b/firewall/packet.go index f4b4ea17..943210b0 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -23,7 +23,10 @@ const ( type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr - LocalPort uint16 + // LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP. + LocalPort uint16 + // RemotePort is the source port for incoming traffic, or the destination port for outgoing. + // For ICMP, it's the "identifier". 
This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier RemotePort uint16 Protocol uint8 Fragment bool @@ -47,6 +50,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = "tcp" case ProtoICMP: proto = "icmp" + case ProtoICMPv6: + proto = "icmpv6" case ProtoUDP: proto = "udp" default: diff --git a/firewall_test.go b/firewall_test.go index 1df62a81..40b57477 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -3,13 +3,13 @@ package nebula import ( "bytes" "errors" + "log/slog" "math" "net/netip" "testing" "time" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) { } func TestFirewall_AddRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) @@ -87,9 +86,10 @@ func TestFirewall_AddRule(t *testing.T) { fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) - assert.Nil(t, fw.InRules.ICMP[1].Any.Any) - assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) - assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") + //no matter what port is given for icmp, it should end up as "any" + assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) + assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) + assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) @@ -176,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) { } func TestFirewall_Drop(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := 
test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ @@ -253,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) { } func TestFirewall_DropV6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -484,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) { } func TestFirewall_Drop2(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -543,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) { } func TestFirewall_Drop3(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -632,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) { } func TestFirewall_Drop3V6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -670,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) { } func TestFirewall_DropConntrackReload(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -734,10 +728,152 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) } -func TestFirewall_DropIPSpoofing(t *testing.T) { - l := test.NewLogger() +func TestFirewall_ICMPPortBehavior(t *testing.T) { ob := &bytes.Buffer{} - l.SetOutput(ob) + l := 
test.NewLoggerWithOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) + + network := netip.MustParsePrefix("1.2.3.4/24") + + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + }, + InvertedGroups: map[string]struct{}{"default-group": {}}, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(myVpnNetworksTable, c.Certificate) + + cp := cert.NewCAPool() + + templ := firewall.Packet{ + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), + Protocol: firewall.ProtoICMP, + Fragment: false, + } + + t.Run("ICMP allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + + t.Run("nonzero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + }) + + t.Run("Any proto, some ports allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, 
[]string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 80 + p.RemotePort = 80 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + t.Run("Any proto, any port", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, allowed", func(t *testing.T) { + resetConntrack(fw) + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + + t.Run("nonzero ports, allowed", func(t 
*testing.T) { + resetConntrack(fw) + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + //different ID is blocked + p.RemotePort++ + require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + +} + +func TestFirewall_DropIPSpoofing(t *testing.T) { + ob := &bytes.Buffer{} + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) @@ -897,56 +1033,56 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} - cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil, "aes") require.NoError(t, err) - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, 
ca_name, or ca_sha must be provided") // Test code/port error - conf = config.NewC(l) - conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}} + conf = config.NewC(test.NewLogger()) + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") - conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups - conf = config.NewC(l) + conf = 
config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") @@ -955,28 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) { l := test.NewLogger() // Test adding tcp rule - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) mf := &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding udp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", 
ip: "", localIp: ""}, mf.lastCall) + + // Test adding icmp rule no port + conf = config.NewC(test.NewLogger()) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -984,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr cidr := netip.MustParsePrefix("10.0.0.0/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) // Test adding rule with local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -999,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr ipv6 cidr6 := netip.MustParsePrefix("fd00::/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": 
[]any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) // Test adding rule with any cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) // Test adding rule with junk cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with local_cidr ipv6 - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) // Test adding rule with any local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) // Test adding rule with junk local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) // Test single group - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test 
single groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test multiple AND groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) // Test Add error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} @@ -1082,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { } func TestFirewall_convertRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) // Ensure group array of 1 is converted and a warning is printed c := map[string]any{ @@ -1092,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) { } r, err := convertRule(l, c, "test", 1) - assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "table=test") + assert.Contains(t, ob.String(), "rule=1") require.NoError(t, err) assert.Equal(t, []string{"group1"}, r.Groups) @@ -1118,9 +1262,8 
@@ func TestFirewall_convertRule(t *testing.T) { } func TestFirewall_convertRuleSanity(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) noWarningPlease := []map[string]any{ {"group": "group1"}, @@ -1234,7 +1377,7 @@ type testsetup struct { fw *Firewall } -func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { +func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup { c := dummyCert{ name: "me", networks: myPrefixes, @@ -1245,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse return newSetupFromCert(t, l, c) } -func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { +func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup { myVpnNetworksTable := new(bart.Lite) for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) @@ -1262,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { t.Parallel() - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myPrefix := netip.MustParsePrefix("1.1.1.1/8") // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out diff --git a/go.mod b/go.mod index 1c564d03..24d901c5 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,10 @@ module github.com/slackhq/nebula -go 1.25 +go 1.25.0 require ( dario.cat/mergo v1.0.2 + filippo.io/bigmod v0.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 @@ -12,26 +13,26 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 - github.com/miekg/dns v1.1.70 - github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b + 
github.com/miekg/dns v1.1.72 + github.com/miekg/pkcs11 v1.1.2 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.23.2 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 - github.com/sirupsen/logrus v1.9.4 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.3.1 + go.uber.org/goleak v1.3.0 go.yaml.in/yaml/v3 v3.0.4 - golang.org/x/crypto v0.47.0 + golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.49.0 - golang.org/x/sync v0.19.0 - golang.org/x/sys v0.40.0 - golang.org/x/term v0.39.0 + golang.org/x/net v0.52.0 + golang.org/x/sync v0.20.0 + golang.org/x/sys v0.43.0 + golang.org/x/term v0.42.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b - golang.zx2c4.com/wireguard/windows v0.5.3 + golang.zx2c4.com/wireguard/windows v0.6.1 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe @@ -49,7 +50,7 @@ require ( github.com/prometheus/procfs v0.16.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/mod v0.31.0 // indirect + golang.org/x/mod v0.34.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.40.0 // indirect + golang.org/x/tools v0.43.0 // indirect ) diff --git a/go.sum b/go.sum index c4613e01..aad164c7 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8= +filippo.io/bigmod 
v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -83,10 +85,10 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA= -github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= -github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= -github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= +github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ= +github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -131,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj 
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= -github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= @@ -162,16 +162,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod 
h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -191,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= 
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -223,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -233,8 +233,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= -golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= -golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= +golang.zx2c4.com/wireguard/windows v0.6.1 
h1:XMaKojH1Hs/raMrmnir4n35nTvzvWj7NmSYzHn2F4qU= +golang.zx2c4.com/wireguard/windows v0.6.1/go.mod h1:04aqInu5GYuTFvMuDw/rKBAF7mHrltW/3rekpfbbZDM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= diff --git a/handshake/credential.go b/handshake/credential.go new file mode 100644 index 00000000..f6cd5f41 --- /dev/null +++ b/handshake/credential.go @@ -0,0 +1,57 @@ +package handshake + +import ( + "crypto/rand" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" +) + +// Credential holds everything needed to participate in a handshake +// at a given cert version. Version and Curve are read from Cert; the public +// half of the static keypair likewise comes from Cert.PublicKey(). +type Credential struct { + Cert cert.Certificate // the certificate + Bytes []byte // pre-marshaled certificate bytes + privateKey []byte // static private key (public half lives in Cert) + cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash) +} + +// NewCredential creates a Credential with all material needed for handshake +// participation. The cipherSuite should be pre-built by the caller with the +// appropriate DH function, cipher, and hash. +func NewCredential( + c cert.Certificate, + hsBytes []byte, + privateKey []byte, + cipherSuite noise.CipherSuite, +) *Credential { + return &Credential{ + Cert: c, + Bytes: hsBytes, + privateKey: privateKey, + cipherSuite: cipherSuite, + } +} + +// buildHandshakeState creates a noise.HandshakeState from this credential. 
+func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) { + return noise.NewHandshakeState(noise.Config{ + CipherSuite: hc.cipherSuite, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()}, + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, + }) +} + +// GetCredentialFunc returns the handshake credential for the given version, +// or nil if that version is not available. +// +// Implementations must return credentials drawn from a snapshot stable for +// the lifetime of any single Machine. The Machine may call this multiple +// times during a handshake (e.g. when negotiating to the peer's version) +// and assumes the underlying static keypair is consistent across calls. +type GetCredentialFunc func(v cert.Version) *Credential diff --git a/handshake/errors.go b/handshake/errors.go new file mode 100644 index 00000000..bb8a5893 --- /dev/null +++ b/handshake/errors.go @@ -0,0 +1,21 @@ +package handshake + +import "errors" + +var ( + ErrInitiateOnResponder = errors.New("initiate called on responder") + ErrInitiateAlreadyCalled = errors.New("initiate already called") + ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators") + ErrPacketTooShort = errors.New("packet too short") + ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake") + ErrIncompleteHandshake = errors.New("handshake completed without receiving required content") + ErrMachineFailed = errors.New("handshake machine has failed") + ErrUnknownSubtype = errors.New("unknown handshake subtype") + ErrMissingContent = errors.New("expected handshake content but message was empty") + ErrUnexpectedContent = errors.New("received unexpected handshake content") + ErrIndexAllocation = errors.New("failed to allocate local index") + ErrNoCredential = errors.New("no handshake credential 
available for cert version") + ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key") + ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager") + ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype") +) diff --git a/handshake/handshake.proto b/handshake/handshake.proto new file mode 100644 index 00000000..72d3b211 --- /dev/null +++ b/handshake/handshake.proto @@ -0,0 +1,37 @@ +// This file documents the wire format the nebula handshake speaks. It is +// not run through protoc; the encoder/decoder in payload.go is hand-written +// against this shape directly to keep the parser narrow and panic-free. +// +// Any change to the wire format must be reflected here, and adding a new +// field requires updating MarshalPayload / unmarshalPayloadDetails together +// with the field-uniqueness and wire-type checks in those functions. + +syntax = "proto3"; +package nebula.handshake; + +message NebulaHandshake { + NebulaHandshakeDetails Details = 1; + bytes Hmac = 2; +} + +message NebulaHandshakeDetails { + bytes Cert = 1; + uint32 InitiatorIndex = 2; + uint32 ResponderIndex = 3; + // Cookie was reserved for an anti-DoS mechanism that was never + // implemented. No released version of nebula has ever populated it; the + // hand-written parser silently skips it on read. 
+ uint64 Cookie = 4 [deprecated = true]; + uint64 Time = 5; + uint32 CertVersion = 8; + + MultiPortDetails InitiatorMultiPort = 6; + MultiPortDetails ResponderMultiPort = 7; +} + +message MultiPortDetails { + bool RxSupported = 1; + bool TxSupported = 2; + uint32 BasePort = 3; + uint32 TotalPorts = 4; +} diff --git a/handshake/helpers_test.go b/handshake/helpers_test.go new file mode 100644 index 00000000..c72346cb --- /dev/null +++ b/handshake/helpers_test.go @@ -0,0 +1,116 @@ +package handshake + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/require" +) + +// testCertState holds cert material for a test peer. +type testCertState struct { + version cert.Version + creds map[cert.Version]*Credential +} + +func (s *testCertState) getCredential(v cert.Version) *Credential { + return s.creds[v] +} + +func newTestCertState( + t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix, +) *testCertState { + return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly) +} + +func newTestCertStateWithCipher( + t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix, + cipher noise.CipherFunc, +) *testCertState { + t.Helper() + c, _, rawPrivKey, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil, + ) + + priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey) + require.NoError(t, err) + + hsBytes, err := c.MarshalForHandshakes() + require.NoError(t, err) + + ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256) + return &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(c, hsBytes, priv, ncs), + }, + } +} + +func testVerifier(pool *cert.CAPool) CertVerifier { + 
return func(c cert.Certificate) (*cert.CachedCertificate, error) { + return pool.VerifyCertificate(time.Now(), c) + } +} + +func newTestMachine( + t *testing.T, + cs *testCertState, + verifier CertVerifier, + initiator bool, + localIndex uint32, +) *Machine { + t.Helper() + m, err := NewMachine( + cs.version, cs.getCredential, + verifier, func() (uint32, error) { return localIndex, nil }, + initiator, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + return m +} + +func initiateHandshake( + t *testing.T, + initCS *testCertState, initVerifier CertVerifier, + respCS *testCertState, respVerifier CertVerifier, +) (initM, respM *Machine, respResult *Result, resp []byte, err error) { + t.Helper() + initM = newTestMachine(t, initCS, initVerifier, true, 100) + msg1, merr := initM.Initiate(nil) + require.NoError(t, merr) + + respM = newTestMachine(t, respCS, respVerifier, false, 200) + resp, respResult, err = respM.ProcessPacket(nil, msg1) + return +} + +func doFullHandshake( + t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool, +) (initResult, respResult *Result) { + t.Helper() + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 1000) + respM := newTestMachine(t, respCS, v, false, 2000) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp, respResult, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, respResult) + require.NotEmpty(t, resp) + + _, initResult, err = initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initResult) + + return initResult, respResult +} diff --git a/handshake/machine.go b/handshake/machine.go new file mode 100644 index 00000000..25ed3a5a --- /dev/null +++ b/handshake/machine.go @@ -0,0 +1,444 @@ +package handshake + +import ( + "bytes" + "fmt" + "slices" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/header" +) + +// IndexAllocator is called by the Machine to allocate a 
local index for the +// handshake. It is called at most once, when the first outgoing message that +// carries a payload is built. +// +// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning +// "no index assigned" on the wire and in the payload-presence checks. If an +// allocator ever returned 0, a legitimate handshake's payload could be +// indistinguishable from an empty one and would be rejected. +type IndexAllocator func() (uint32, error) + +// CertVerifier is called by the Machine after reconstructing the peer's +// certificate from the handshake. The verifier performs all validation +// (CA trust, expiry, policy checks, allow lists). +type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error) + +// Result contains the results of a successful handshake. +// Returned by ProcessPacket when the handshake is complete. +type Result struct { + EKey *noise.CipherState + DKey *noise.CipherState + MyCert cert.Certificate + RemoteCert *cert.CachedCertificate + RemoteIndex uint32 + LocalIndex uint32 + HandshakeTime uint64 + MessageIndex uint64 // number of messages exchanged during the handshake + Initiator bool +} + +// Machine drives a Noise handshake through N messages. It handles Noise +// protocol operations, certificate reconstruction, and payload encoding. +// Certificate validation is delegated to the caller via CertVerifier. +// +// A Machine is not safe for concurrent use. The caller must ensure that +// Initiate and ProcessPacket are not called concurrently. +// +// Error contract: when ProcessPacket or Initiate returns an error, callers +// must check Failed() to decide what to do next. If Failed() is false the +// underlying noise state was not advanced (the packet was rejected before +// ReadMessage took effect, or the rejection is non-fatal like a stale +// retransmit) and the Machine can accept another packet. If Failed() is +// true the Machine is unrecoverable and the caller must abandon it. 
+type Machine struct { + hs *noise.HandshakeState + getCred GetCredentialFunc + allocIndex IndexAllocator + verifier CertVerifier + result *Result + msgs []msgFlags + myVersion cert.Version + subtype header.MessageSubType + indexAllocated bool + remoteCertSet bool + payloadSet bool + failed bool +} + +// NewMachine creates a handshake state machine. The subtype determines both +// the noise pattern and the per-message content layout. The credential for +// `version` is fetched via getCred and used to seed the noise.HandshakeState. +// IndexAllocator is called lazily when the first outgoing payload is built. +func NewMachine( + version cert.Version, + getCred GetCredentialFunc, + verifier CertVerifier, + allocIndex IndexAllocator, + initiator bool, + subtype header.MessageSubType, +) (*Machine, error) { + info, err := subtypeInfoFor(subtype) + if err != nil { + return nil, err + } + + cred := getCred(version) + if cred == nil { + return nil, fmt.Errorf("%w: %v", ErrNoCredential, version) + } + + hs, err := cred.buildHandshakeState(initiator, info.pattern) + if err != nil { + return nil, fmt.Errorf("build noise state: %w", err) + } + + return &Machine{ + hs: hs, + subtype: subtype, + msgs: info.msgs, + getCred: getCred, + allocIndex: allocIndex, + verifier: verifier, + myVersion: version, + result: &Result{ + Initiator: initiator, + }, + }, nil +} + +// Failed returns true if the Machine is in an unrecoverable state. +func (m *Machine) Failed() bool { + return m.failed +} + +// Subtype returns the handshake subtype this Machine was built for. +func (m *Machine) Subtype() header.MessageSubType { + return m.subtype +} + +// MessageIndex returns the noise handshake message index, which equals the +// wire counter of the most recently sent or received message. +func (m *Machine) MessageIndex() int { + return m.hs.MessageIndex() +} + +// requireComplete checks that both a peer cert and payload have been received. +// Marks the machine as failed if not. 
+func (m *Machine) requireComplete() error { + if !m.payloadSet || !m.remoteCertSet { + m.failed = true + return ErrIncompleteHandshake + } + return nil +} + +// myMsgFlags returns the flags for the current outgoing message. +func (m *Machine) myMsgFlags() msgFlags { + idx := m.hs.MessageIndex() + if idx < len(m.msgs) { + return m.msgs[idx] + } + return msgFlags{} +} + +// peerMsgFlags returns the flags for the message we just read. +func (m *Machine) peerMsgFlags() msgFlags { + idx := m.hs.MessageIndex() - 1 + if idx >= 0 && idx < len(m.msgs) { + return m.msgs[idx] + } + return msgFlags{} +} + +// Initiate produces the first handshake message. Only valid for initiators, +// and must be called exactly once before ProcessPacket. +// +// out is a destination buffer the message is appended to and returned. Pass +// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g. +// buf[:0]) with sufficient capacity to avoid allocation. +// +// An error return may not indicate a fatal condition, check Failed() to +// determine if the Machine can still be used. +func (m *Machine) Initiate(out []byte) ([]byte, error) { + if m.failed { + return nil, ErrMachineFailed + } + if !m.result.Initiator { + m.failed = true + return nil, ErrInitiateOnResponder + } + if m.hs.MessageIndex() != 0 { + m.failed = true + return nil, ErrInitiateAlreadyCalled + } + + // At MessageIndex=0 with RemoteIndex still zero, buildResponse produces + // header counter 1 and remote index 0, which is what the initial message needs. + out, _, _, err := m.buildResponse(out) + if err != nil { + m.failed = true + return nil, err + } + return out, nil +} + +// ProcessPacket handles an incoming handshake message. It advances the Noise +// state, validates the peer certificate via the verifier, and optionally +// produces a response. +// +// out is a destination buffer the response is appended to and returned. Pass +// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g. 
+// buf[:0]) with sufficient capacity to avoid allocation. The returned slice +// is nil when no outgoing message is produced (handshake complete on this +// side, or final message of a multi-message pattern). +// +// Returns a non-nil Result when the handshake is complete. +// An error return may not indicate a fatal condition, check Failed() to +// determine if the Machine can still be used. +func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) { + if m.failed { + return nil, nil, ErrMachineFailed + } + if len(packet) < header.Len { + return nil, nil, ErrPacketTooShort + } + // Reject packets whose subtype doesn't match the one this Machine was + // built for. A pending handshake that suddenly receives a different + // subtype on its index is either a stray packet that matched by chance + // or a peer protocol violation; drop it without failing the Machine so + // the legitimate retransmit can still complete. + if header.MessageSubType(packet[1]) != m.subtype { + return nil, nil, ErrSubtypeMismatch + } + if m.result.Initiator && m.hs.MessageIndex() == 0 { + m.failed = true + return nil, nil, ErrInitiateNotCalled + } + + // The (eKey, dKey) ordering here is correct for IX, where the initiator + // completes the handshake by reading the responder's stage-2 message. + // noise returns (cs1, cs2) where cs1 is the initiator->responder cipher. + // For 3-message patterns where a responder finishes by reading the final + // message, this ordering would be wrong; revisit when XX/pqIX lands. + msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:]) + if err != nil { + // Noise ReadMessage failed. The noise library checkpoints and rolls back + // on failure, so the Machine is still alive. The caller can retry with + // a different packet. + return nil, nil, fmt.Errorf("noise ReadMessage: %w", err) + } + + // From here on, noise state has advanced. Any error is fatal. 
+ flags := m.peerMsgFlags() + + if err := m.processPayload(msg, flags); err != nil { + return nil, nil, err + } + + // If ReadMessage derived keys, the handshake is complete. Noise should + // always produce both keys together; asymmetry is a protocol invariant + // violation. + if eKey != nil || dKey != nil { + if eKey == nil || dKey == nil { + m.failed = true + return nil, nil, ErrAsymmetricCipherKeys + } + if err := m.requireComplete(); err != nil { + return nil, nil, err + } + return nil, m.completed(eKey, dKey), nil + } + + // ReadMessage didn't complete, produce the next outgoing message + out, dk, ek, err := m.buildResponse(out) + if err != nil { + m.failed = true + return nil, nil, err + } + + if ek != nil || dk != nil { + if ek == nil || dk == nil { + m.failed = true + return nil, nil, ErrAsymmetricCipherKeys + } + if err := m.requireComplete(); err != nil { + return nil, nil, err + } + return out, m.completed(ek, dk), nil + } + + return out, nil, nil +} + +func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result { + m.result.EKey = eKey + m.result.DKey = dKey + m.result.MessageIndex = uint64(m.hs.MessageIndex()) + return m.result +} + +func (m *Machine) processPayload(msg []byte, flags msgFlags) error { + if len(msg) == 0 { + if flags.expectsPayload || flags.expectsCert { + m.failed = true + return ErrMissingContent + } + return nil + } + + payload, err := UnmarshalPayload(msg) + if err != nil { + m.failed = true + return fmt.Errorf("unmarshal handshake: %w", err) + } + + // Assert the payload contains exactly what we expect + hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 + if hasPayloadData != flags.expectsPayload { + m.failed = true + return ErrUnexpectedContent + } + + hasCertData := len(payload.Cert) > 0 + if hasCertData != flags.expectsCert { + m.failed = true + return ErrUnexpectedContent + } + + // Process payload + if flags.expectsPayload { + if m.result.Initiator { + 
m.result.RemoteIndex = payload.ResponderIndex + } else { + m.result.RemoteIndex = payload.InitiatorIndex + } + m.result.HandshakeTime = payload.Time + m.payloadSet = true + } + + // Process certificate + if flags.expectsCert { + if err := m.validateCert(payload); err != nil { + return err + } + } + + return nil +} + +func (m *Machine) validateCert(payload Payload) error { + cred := m.getCred(m.myVersion) + if cred == nil { + m.failed = true + return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion) + } + rc, err := cert.Recombine( + cert.Version(payload.CertVersion), + payload.Cert, + m.hs.PeerStatic(), + cred.Cert.Curve(), + ) + if err != nil { + m.failed = true + return fmt.Errorf("recombine cert: %w", err) + } + + if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) { + m.failed = true + return ErrPublicKeyMismatch + } + + // Version negotiation, if the peer sent a different version and we have it, switch + if rc.Version() != m.myVersion { + if m.getCred(rc.Version()) != nil { + m.myVersion = rc.Version() + } + } + + verified, err := m.verifier(rc) + if err != nil { + m.failed = true + return fmt.Errorf("verify cert: %w", err) + } + + m.result.RemoteCert = verified + m.remoteCertSet = true + return nil +} + +func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) { + if !flags.expectsPayload && !flags.expectsCert { + return nil, nil + } + + var p Payload + if flags.expectsPayload { + if !m.indexAllocated { + index, err := m.allocIndex() + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err) + } + m.result.LocalIndex = index + m.indexAllocated = true + } + + if m.result.Initiator { + p.InitiatorIndex = m.result.LocalIndex + } else { + p.ResponderIndex = m.result.LocalIndex + p.InitiatorIndex = m.result.RemoteIndex + } + p.Time = uint64(time.Now().UnixNano()) + } + if flags.expectsCert { + cred := m.getCred(m.myVersion) + if cred == nil { + return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion) + } + p.Cert = cred.Bytes + 
p.CertVersion = uint32(cred.Cert.Version()) + m.result.MyCert = cred.Cert + } + + return MarshalPayload(nil, p), nil +} + +func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) { + flags := m.myMsgFlags() + hsBytes, err := m.marshalOutgoing(flags) + if err != nil { + return nil, nil, nil, err + } + + // Extend out by header.Len to make room for the header. slices.Grow is a + // no-op when the cap is already sufficient (the zero-copy case where the + // caller passed a pre-sized buffer). header.Encode overwrites the new + // bytes, so they don't need to be zeroed. + start := len(out) + out = slices.Grow(out, header.Len)[:start+header.Len] + header.Encode( + out[start:], + header.Version, header.Handshake, m.subtype, + m.result.RemoteIndex, + uint64(m.hs.MessageIndex()+1), + ) + + // noise.WriteMessage appends the encrypted handshake message to out, + // reusing capacity when present. + // + // The (dKey, eKey) ordering here is correct for IX, where the responder + // completes the handshake by writing the stage-2 message. noise returns + // (cs1, cs2) where cs1 is the initiator->responder cipher (which is the + // responder's decrypt key). For 3-message patterns where an initiator + // finishes by writing the final message, this ordering would be wrong; + // revisit when XX/pqIX lands. 
+ out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err) + } + + return out, dKey, eKey, nil +} diff --git a/handshake/machine_test.go b/handshake/machine_test.go new file mode 100644 index 00000000..722a39e1 --- /dev/null +++ b/handshake/machine_test.go @@ -0,0 +1,662 @@ +package handshake + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/noiseutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMachineIXHappyPath(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name()) + assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name()) + + assert.Equal(t, uint32(1000), initR.LocalIndex) + assert.Equal(t, uint32(2000), initR.RemoteIndex) + assert.Equal(t, uint32(2000), respR.LocalIndex) + assert.Equal(t, uint32(1000), respR.RemoteIndex) + + assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages") + assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages") + + ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello")) + require.NoError(t, err) + pt1, err := respR.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("hello"), pt1) + + ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world")) + require.NoError(t, err) + pt2, err := initR.DKey.Decrypt(nil, nil, 
ct2) + require.NoError(t, err) + assert.Equal(t, []byte("world"), pt2) +} + +func TestMachineInitiateErrors(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("initiate on responder", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, err := m.Initiate(nil) + require.ErrorIs(t, err, ErrInitiateOnResponder) + assert.True(t, m.Failed()) + }) + + t.Run("initiate called twice", func(t *testing.T) { + m := newTestMachine(t, cs, v, true, 100) + _, err := m.Initiate(nil) + require.NoError(t, err) + _, err = m.Initiate(nil) + require.ErrorIs(t, err, ErrInitiateAlreadyCalled) + assert.True(t, m.Failed()) + }) + + t.Run("process packet before initiate on initiator", func(t *testing.T) { + m := newTestMachine(t, cs, v, true, 100) + _, _, err := m.ProcessPacket(nil, make([]byte, 100)) + require.ErrorIs(t, err, ErrInitiateNotCalled) + assert.True(t, m.Failed()) + }) + + t.Run("calling failed machine", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, err := m.Initiate(nil) // fails: responder + require.Error(t, err) + _, err = m.Initiate(nil) // fails: already failed + require.ErrorIs(t, err, ErrMachineFailed) + }) +} + +func TestMachineProcessPacketErrors(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("packet too short", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, _, err := m.ProcessPacket(nil, []byte{1, 2, 3}) + require.ErrorIs(t, err, ErrPacketTooShort) + assert.False(t, m.Failed(), "short 
packet should not kill machine") + }) + + t.Run("noise decryption failure is recoverable", func(t *testing.T) { + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + initM := newTestMachine(t, initCS, v, true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + respM := newTestMachine(t, cs, v, false, 200) + resp, _, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + + corrupted := make([]byte, len(resp)) + copy(corrupted, resp) + for i := header.Len; i < len(corrupted); i++ { + corrupted[i] ^= 0xff + } + _, _, err = initM.ProcessPacket(nil, corrupted) + require.Error(t, err) + assert.False(t, initM.Failed(), "noise failure should be recoverable") + + // And the machine should still complete a real handshake afterward. + _, result, err := initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, result, "initiator should complete on the legitimate response") + }) + + t.Run("invalid cert is fatal", func(t *testing.T) { + otherCA, _, otherCAKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + respM := newTestMachine(t, cs, v, false, 200) + _, _, err = respM.ProcessPacket(nil, msg1) + require.Error(t, err) + assert.True(t, respM.Failed(), "cert validation failure should kill machine") + }) + + t.Run("subtype mismatch is recoverable", func(t *testing.T) { + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + initM := newTestMachine(t, initCS, v, true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + // Mutate the subtype byte (offset 1 in the header) to a value the 
+ // responder Machine wasn't built for. + bad := make([]byte, len(msg1)) + copy(bad, msg1) + bad[1] = 0xff + + respM := newTestMachine(t, cs, v, false, 200) + _, _, err = respM.ProcessPacket(nil, bad) + require.ErrorIs(t, err, ErrSubtypeMismatch) + assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine") + + // And the machine should still complete a real handshake afterward. + resp, result, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet") + assert.NotEmpty(t, resp, "responder should produce a stage-2 reply") + }) +} + +// TestMachineProcessPayload exercises processPayload's internal validation +// directly. Most of these failure modes can't be reached black-box once the +// subtype check at the top of ProcessPacket gates external callers, so we +// drive them by hand here for coverage. +func TestMachineProcessPayload(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("empty message with expects fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true}) + require.ErrorIs(t, err, ErrMissingContent) + assert.True(t, m.Failed()) + }) + + t.Run("empty message with no expects passes", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload(nil, msgFlags{}) + require.NoError(t, err) + assert.False(t, m.Failed()) + }) + + t.Run("malformed protobuf is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true}) + require.Error(t, err) + 
assert.True(t, m.Failed()) + }) + + t.Run("unexpected payload data is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // A payload with index data when none was expected. + bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}) + err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) + + t.Run("unexpected cert data is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // A payload with cert when none was expected. + bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}) + err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) + + t.Run("missing payload data when expected is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // Cert present, but no index/time fields. + bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}) + err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) +} + +// TestMachineRequireComplete checks the fail-on-incomplete-handshake path +// directly. Like processPayload above this isn't reachable from a normal IX +// flow, so we drive it by hand. 
+func TestMachineRequireComplete(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("missing both fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("payload only fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.payloadSet = true + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("cert only fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.remoteCertSet = true + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("both set passes", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.payloadSet = true + m.remoteCertSet = true + err := m.requireComplete() + require.NoError(t, err) + assert.False(t, m.Failed()) + }) +} + +func TestMachineAESCipher(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + initCS := newTestCertStateWithCipher( + t, ca, caKey, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + noiseutil.CipherAESGCM, + ) + respCS := newTestCertStateWithCipher( + t, ca, caKey, "resp", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + noiseutil.CipherAESGCM, + ) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works")) + require.NoError(t, err) + pt1, err := respR.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, 
[]byte("works"), pt1) + + ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back")) + require.NoError(t, err) + pt2, err := initR.DKey.Decrypt(nil, nil, ct2) + require.NoError(t, err) + assert.Equal(t, []byte("back"), pt2) +} + +func TestResultFields(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + assert.True(t, initR.Initiator) + assert.False(t, respR.Initiator) + assert.NotZero(t, initR.HandshakeTime) + assert.NotZero(t, respR.HandshakeTime) + assert.NotNil(t, initR.RemoteCert) + assert.NotNil(t, respR.RemoteCert) +} + +func TestMachineBufferReuse(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 1000) + respM := newTestMachine(t, respCS, v, false, 2000) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + t.Run("response writes into provided buffer", func(t *testing.T) { + buf := make([]byte, 0, 4096) + resp, result, err := respM.ProcessPacket(buf, msg1) + require.NoError(t, err) + require.NotNil(t, result) + + assert.NotEmpty(t, resp, "response should have content") + assert.Equal(t, &buf[:1][0], &resp[:1][0], + "response should reuse the provided buffer's backing array") + }) + + t.Run("initiate writes into provided buffer", func(t *testing.T) { + 
initM2 := newTestMachine(t, initCS, v, true, 3000) + buf := make([]byte, 0, 4096) + msg, err := initM2.Initiate(buf) + require.NoError(t, err) + + assert.NotEmpty(t, msg, "initiate should have content") + assert.Equal(t, &buf[:1][0], &msg[:1][0], + "initiate should reuse the provided buffer's backing array") + }) + + t.Run("nil out still works", func(t *testing.T) { + initM2 := newTestMachine(t, initCS, v, true, 4000) + respM2 := newTestMachine(t, respCS, v, false, 5000) + + msg1, err := initM2.Initiate(nil) + require.NoError(t, err) + + resp, _, err := respM2.ProcessPacket(nil, msg1) + require.NoError(t, err) + + out, result, err := initM2.ProcessPacket(nil, resp) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Nil(t, out, "initiator should have no response for IX msg2") + }) +} + +func TestMachineMsgIndexTracking(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 100) + respM := newTestMachine(t, respCS, v, false, 200) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp1, result1, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + assert.NotNil(t, result1) + + _, result2, err := initM.ProcessPacket(nil, resp1) + require.NoError(t, err) + assert.NotNil(t, result2) +} + +func TestMachineThreeMessagePattern(t *testing.T) { + registerTestXXInfo(t) + + // Use HandshakeXX (3 messages) to verify the Machine handles multi-message + // patterns correctly. 
XX flow: + // msg1 (I->R): [E] - payload only, no cert + // msg2 (R->I): [E, ee, S, es] - payload + cert + // msg3 (I->R): [S, se] - cert only (no payload, not first two) + + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + v := testVerifier(caPool) + + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM, err := NewMachine( + cert.Version2, + initCS.getCredential, v, + func() (uint32, error) { return 1000, nil }, + true, header.HandshakeXXPSK0, + ) + require.NoError(t, err) + + respM, err := NewMachine( + cert.Version2, + respCS.getCredential, v, + func() (uint32, error) { return 2000, nil }, + false, header.HandshakeXXPSK0, + ) + require.NoError(t, err) + + // msg1: initiator -> responder (E only, no cert) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + assert.NotEmpty(t, msg1) + + // Responder processes msg1, should not complete yet, should produce msg2 + msg2, result, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + assert.Nil(t, result, "XX should not complete on msg1") + assert.NotEmpty(t, msg2, "responder should produce msg2") + + // Initiator processes msg2: gets responder's cert, produces msg3, and + // completes (WriteMessage for msg3 derives keys) + msg3, initResult, err := initM.ProcessPacket(nil, msg2) + require.NoError(t, err) + require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3") + assert.NotEmpty(t, msg3, "initiator should produce msg3") + assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name()) + + // Responder processes msg3: gets initiator's cert and completes + _, respResult, err := respM.ProcessPacket(nil, msg3) + require.NoError(t, err) + require.NotNil(t, respResult, "XX responder should 
complete on msg3") + assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name()) + + assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages") + assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages") + + // Verify keys work + ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages")) + require.NoError(t, err) + pt1, err := respResult.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("three messages"), pt1) +} + +// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with +// IX since the cert is always in the payload. A 3-message pattern test (HybridIX) +// should exercise the case where cert arrives in msg3 and verify that completing +// without it fails. + +func TestMachineExpiredCert(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, + time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour), + nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + expCert, _, expKeyPEM, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + "expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour), + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil, + ) + expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM) + require.NoError(t, err) + expHsBytes, err := expCert.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + + expiredCS := &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs), + }, + } + + respCS := newTestCertState( + t, ca, caKey, "responder", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, expiredCS, testVerifier(caPool), + respCS, testVerifier(caPool), + ) + require.ErrorContains(t, err, "verify cert") + assert.True(t, 
respM.Failed()) +} + +func TestMachineNoCertNetworks(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + caHsBytes, err := ca.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + + noNetCS := &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs), + }, + } + + respCS := newTestCertState( + t, ca, caKey, "responder", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, noNetCS, testVerifier(caPool), + respCS, testVerifier(caPool), + ) + require.Error(t, err) + assert.True(t, respM.Failed()) +} + +func TestMachineDifferentCAs(t *testing.T) { + ca1, _, caKey1, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + ca2, _, caKey2, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + + initCS := newTestCertState( + t, ca1, caKey1, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + respCS := newTestCertState( + t, ca2, caKey2, "resp", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, initCS, testVerifier(ct.NewTestCAPool(ca1)), + respCS, testVerifier(ct.NewTestCAPool(ca2)), + ) + require.ErrorContains(t, err, "verify cert") + assert.True(t, respM.Failed()) +} + +func TestMachineVersionNegotiation(t *testing.T) { + ca1, _, caKey1, _ := ct.NewTestCaCert( + cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + ca2, _, caKey2, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca1, ca2) + + makeMultiVersionResp := func(t 
*testing.T) *testCertState { + t.Helper() + respCertV1, _, respKeyPEM, _ := ct.NewTestCert( + cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp", + ca1.NotBefore(), ca1.NotAfter(), + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil, + ) + respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM) + respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2) + respHsV1, _ := respCertV1.MarshalForHandshakes() + respHsV2, _ := respCertV2.MarshalForHandshakes() + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + return &testCertState{ + version: cert.Version1, + creds: map[cert.Version]*Credential{ + cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs), + cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs), + }, + } + } + + t.Run("responder matches initiator version", func(t *testing.T) { + initCS := newTestCertState( + t, ca2, caKey2, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + respCS := makeMultiVersionResp(t) + v := testVerifier(caPool) + + initM, _, respResult, resp, err := initiateHandshake( + t, initCS, v, + respCS, v, + ) + require.NoError(t, err) + require.NotNil(t, respResult) + + assert.Equal(t, cert.Version2, respResult.MyCert.Version(), + "responder should negotiate to initiator's version") + + _, initResult, err := initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initResult) + assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(), + "initiator should see V2 cert from responder") + }) + + t.Run("responder keeps version when no match available", func(t *testing.T) { + initCS := newTestCertState( + t, ca2, caKey2, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + + respCert, _, respKeyPEM, _ := ct.NewTestCert( + cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp", + ca1.NotBefore(), ca1.NotAfter(), + 
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil, + ) + respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM) + respHs, _ := respCert.MarshalForHandshakes() + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + respCS := &testCertState{ + version: cert.Version1, + creds: map[cert.Version]*Credential{ + cert.Version1: NewCredential(respCert, respHs, respKey, ncs), + }, + } + + v := testVerifier(caPool) + _, _, respResult, _, err := initiateHandshake( + t, initCS, v, + respCS, v, + ) + require.NoError(t, err) + require.NotNil(t, respResult) + + assert.Equal(t, cert.Version1, respResult.MyCert.Version(), + "responder should keep V1 when V2 not available") + }) +} diff --git a/handshake/patterns.go b/handshake/patterns.go new file mode 100644 index 00000000..a0cc1a70 --- /dev/null +++ b/handshake/patterns.go @@ -0,0 +1,54 @@ +package handshake + +import ( + "fmt" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/header" +) + +// msgFlags tracks what application data a handshake message carries. +type msgFlags struct { + expectsPayload bool // message carries indexes and time + expectsCert bool // message carries the certificate +} + +// subtypeInfo bundles the noise pattern with the per-message flags for a +// given handshake subtype. +type subtypeInfo struct { + pattern noise.HandshakePattern + msgs []msgFlags +} + +// subtypeInfos defines the noise pattern and message content layout for each +// handshake subtype. 
+var subtypeInfos = map[header.MessageSubType]subtypeInfo{ + // IX: 2 messages, both carry payload and cert + header.HandshakeIXPSK0: { + pattern: noise.HandshakeIX, + msgs: []msgFlags{ + {expectsPayload: true, expectsCert: true}, + {expectsPayload: true, expectsCert: true}, + }, + }, + + // XX: 3 messages + // msg1 (I->R): payload only + // msg2 (R->I): payload + cert + // msg3 (I->R): cert only + //header.HandshakeXXPSK0: { + // pattern: noise.HandshakeXX, + // msgs: []msgFlags{ + // {expectsPayload: true, expectsCert: false}, + // {expectsPayload: true, expectsCert: true}, + // {expectsPayload: false, expectsCert: true}, + // }, + //}, +} + +func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) { + if info, ok := subtypeInfos[subtype]; ok { + return info, nil + } + return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype) +} diff --git a/handshake/patterns_test.go b/handshake/patterns_test.go new file mode 100644 index 00000000..d6207e00 --- /dev/null +++ b/handshake/patterns_test.go @@ -0,0 +1,63 @@ +package handshake + +import ( + "testing" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSubtypeInfo(t *testing.T) { + t.Run("IX", func(t *testing.T) { + info, err := subtypeInfoFor(header.HandshakeIXPSK0) + require.NoError(t, err) + assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name) + require.Len(t, info.msgs, 2) + // msg1: payload + cert + assert.True(t, info.msgs[0].expectsPayload) + assert.True(t, info.msgs[0].expectsCert) + // msg2: payload + cert + assert.True(t, info.msgs[1].expectsPayload) + assert.True(t, info.msgs[1].expectsCert) + }) + + t.Run("XX", func(t *testing.T) { + registerTestXXInfo(t) + info, err := subtypeInfoFor(header.HandshakeXXPSK0) + require.NoError(t, err) + assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name) + require.Len(t, info.msgs, 3) + // msg1: payload only + assert.True(t, 
info.msgs[0].expectsPayload) + assert.False(t, info.msgs[0].expectsCert) + // msg2: payload + cert + assert.True(t, info.msgs[1].expectsPayload) + assert.True(t, info.msgs[1].expectsCert) + // msg3: cert only + assert.False(t, info.msgs[2].expectsPayload) + assert.True(t, info.msgs[2].expectsCert) + }) + + t.Run("unknown subtype returns error", func(t *testing.T) { + _, err := subtypeInfoFor(99) + require.ErrorIs(t, err, ErrUnknownSubtype) + }) +} + +// registerTestXXInfo temporarily registers XX subtype info for testing. +func registerTestXXInfo(t *testing.T) { + t.Helper() + subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{ + pattern: noise.HandshakeXX, + msgs: []msgFlags{ + {expectsPayload: true, expectsCert: false}, + {expectsPayload: true, expectsCert: true}, + {expectsPayload: false, expectsCert: true}, + }, + } + t.Cleanup(func() { + delete(subtypeInfos, header.HandshakeXXPSK0) + }) +} diff --git a/handshake/payload.go b/handshake/payload.go new file mode 100644 index 00000000..4567fc0d --- /dev/null +++ b/handshake/payload.go @@ -0,0 +1,173 @@ +package handshake + +import ( + "errors" + "math" + + "google.golang.org/protobuf/encoding/protowire" +) + +var ( + errInvalidHandshakeMessage = errors.New("invalid handshake message") + errInvalidHandshakeDetails = errors.New("invalid handshake details") +) + +// Payload represents the decoded fields of a handshake message. +// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. +type Payload struct { + Cert []byte + InitiatorIndex uint32 + ResponderIndex uint32 + Time uint64 + CertVersion uint32 +} + +// Proto field numbers for NebulaHandshakeDetails +const ( + fieldCert = 1 // bytes + fieldInitiatorIndex = 2 // uint32 + fieldResponderIndex = 3 // uint32 + fieldTime = 5 // uint64 + fieldCertVersion = 8 // uint32 +) + +// MarshalPayload encodes a handshake payload in protobuf wire format compatible +// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. 
+// Returns out (which may be nil), with the marshalled Payload appended to it. +func MarshalPayload(out []byte, p Payload) []byte { + var details []byte + + if len(p.Cert) > 0 { + details = protowire.AppendTag(details, fieldCert, protowire.BytesType) + details = protowire.AppendBytes(details, p.Cert) + } + if p.InitiatorIndex != 0 { + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.InitiatorIndex)) + } + if p.ResponderIndex != 0 { + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.ResponderIndex)) + } + if p.Time != 0 { + details = protowire.AppendTag(details, fieldTime, protowire.VarintType) + details = protowire.AppendVarint(details, p.Time) + } + if p.CertVersion != 0 { + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.CertVersion)) + } + + out = protowire.AppendTag(out, 1, protowire.BytesType) + out = protowire.AppendBytes(out, details) + + return out +} + +// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message. 
+func UnmarshalPayload(b []byte) (Payload, error) { + var p Payload + + for len(b) > 0 { + num, typ, n := protowire.ConsumeTag(b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + + switch { + case num == 1 && typ == protowire.BytesType: + details, n := protowire.ConsumeBytes(b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + if err := unmarshalPayloadDetails(&p, details); err != nil { + return p, err + } + default: + n := protowire.ConsumeFieldValue(num, typ, b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + } + } + + return p, nil +} + +func unmarshalPayloadDetails(p *Payload, b []byte) error { + for len(b) > 0 { + num, typ, n := protowire.ConsumeTag(b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + + // For known field numbers, reject any non-matching wire type as a + // hard error rather than silently skipping. The caller will catch + // missing-field cases downstream, but a wire-type mismatch on a tag + // we know is a peer protocol violation worth flagging here. + // Repeated occurrences of a singular field follow proto3 last-wins. + switch num { + case fieldCert: + if typ != protowire.BytesType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeBytes(b) + if n < 0 { + return errInvalidHandshakeDetails + } + p.Cert = append([]byte(nil), v...) 
+ b = b[n:] + case fieldInitiatorIndex: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.InitiatorIndex = uint32(v) + b = b[n:] + case fieldResponderIndex: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.ResponderIndex = uint32(v) + b = b[n:] + case fieldTime: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 { + return errInvalidHandshakeDetails + } + p.Time = v + b = b[n:] + case fieldCertVersion: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.CertVersion = uint32(v) + b = b[n:] + default: + n := protowire.ConsumeFieldValue(num, typ, b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + } + } + return nil +} diff --git a/handshake/payload_test.go b/handshake/payload_test.go new file mode 100644 index 00000000..2ff3231c --- /dev/null +++ b/handshake/payload_test.go @@ -0,0 +1,361 @@ +package handshake + +import ( + "bytes" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestPayloadRoundTrip(t *testing.T) { + t.Run("all fields set", func(t *testing.T) { + data := MarshalPayload(nil, Payload{ + Cert: []byte("test-cert-bytes"), + CertVersion: 2, + InitiatorIndex: 12345, + ResponderIndex: 67890, + Time: 1234567890, + }) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, []byte("test-cert-bytes"), got.Cert) + assert.Equal(t, uint32(12345), got.InitiatorIndex) + assert.Equal(t, uint32(67890), got.ResponderIndex) + 
assert.Equal(t, uint64(1234567890), got.Time) + assert.Equal(t, uint32(2), got.CertVersion) + }) + + t.Run("minimal fields", func(t *testing.T) { + data := MarshalPayload(nil, Payload{InitiatorIndex: 1}) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, uint32(1), got.InitiatorIndex) + assert.Equal(t, uint32(0), got.ResponderIndex) + assert.Equal(t, uint64(0), got.Time) + assert.Nil(t, got.Cert) + }) + + t.Run("empty payload", func(t *testing.T) { + data := MarshalPayload(nil, Payload{}) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, uint32(0), got.InitiatorIndex) + }) + + t.Run("large cert bytes", func(t *testing.T) { + bigCert := make([]byte, 4096) + for i := range bigCert { + bigCert[i] = byte(i % 256) + } + + data := MarshalPayload(nil, Payload{ + Cert: bigCert, + CertVersion: 2, + InitiatorIndex: 999, + }) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, bigCert, got.Cert) + assert.Equal(t, uint32(999), got.InitiatorIndex) + }) + + t.Run("append to existing buffer", func(t *testing.T) { + prefix := []byte("prefix") + data := MarshalPayload(prefix, Payload{InitiatorIndex: 42}) + + assert.Equal(t, []byte("prefix"), data[:6]) + + got, err := UnmarshalPayload(data[6:]) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) +} + +func TestPayloadUnknownFields(t *testing.T) { + t.Run("unknown field in outer message is skipped", func(t *testing.T) { + // Marshal a normal payload then append an unknown field (field 99, varint) + data := MarshalPayload(nil, Payload{InitiatorIndex: 42}) + data = protowire.AppendTag(data, 99, protowire.VarintType) + data = protowire.AppendVarint(data, 12345) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) + + t.Run("unknown field in details is skipped", func(t *testing.T) { + // Build details with a known field + unknown field + var 
details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 77) + // Unknown field 50, varint + details = protowire.AppendTag(details, 50, protowire.VarintType) + details = protowire.AppendVarint(details, 9999) + // Another known field after the unknown one + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 88) + + // Wrap in outer message + var data []byte + data = protowire.AppendTag(data, 1, protowire.BytesType) + data = protowire.AppendBytes(data, details) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(77), got.InitiatorIndex) + assert.Equal(t, uint32(88), got.ResponderIndex) + }) + + t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) { + // Fields 6 and 7 are reserved in the proto definition + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 100) + details = protowire.AppendTag(details, 6, protowire.VarintType) + details = protowire.AppendVarint(details, 1) + details = protowire.AppendTag(details, 7, protowire.VarintType) + details = protowire.AppendVarint(details, 2) + + var data []byte + data = protowire.AppendTag(data, 1, protowire.BytesType) + data = protowire.AppendBytes(data, details) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(100), got.InitiatorIndex) + }) +} + +func TestPayloadBytesConsumed(t *testing.T) { + t.Run("all bytes consumed on valid input", func(t *testing.T) { + original := Payload{ + Cert: []byte("cert"), + CertVersion: 2, + InitiatorIndex: 100, + ResponderIndex: 200, + Time: 999, + } + data := MarshalPayload(nil, original) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + // Re-marshal and compare — proves we consumed and reproduced all fields + remarshaled := 
MarshalPayload(nil, got) + assert.Equal(t, data, remarshaled) + }) +} + +// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope +// so UnmarshalPayload can reach unmarshalPayloadDetails. +func wrapDetails(details []byte) []byte { + var out []byte + out = protowire.AppendTag(out, 1, protowire.BytesType) + out = protowire.AppendBytes(out, details) + return out +} + +func TestPayloadUnmarshalErrors(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + got, err := UnmarshalPayload(nil) + require.NoError(t, err) + assert.Equal(t, uint32(0), got.InitiatorIndex) + }) + + t.Run("truncated outer tag", func(t *testing.T) { + _, err := UnmarshalPayload([]byte{0x80}) + assert.Error(t, err) + }) + + t.Run("truncated outer details field", func(t *testing.T) { + _, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05}) + assert.Error(t, err) + }) + + t.Run("truncated outer unknown field", func(t *testing.T) { + // Valid tag for unknown field 99 varint, but no value follows + var data []byte + data = protowire.AppendTag(data, 99, protowire.VarintType) + _, err := UnmarshalPayload(data) + assert.Error(t, err) + }) + + t.Run("truncated details tag", func(t *testing.T) { + _, err := UnmarshalPayload(wrapDetails([]byte{0x80})) + assert.Error(t, err) + }) + + t.Run("truncated cert bytes", func(t *testing.T) { + // Field 1 (cert), bytes type, length 10 but only 2 bytes + var details []byte + details = protowire.AppendTag(details, fieldCert, protowire.BytesType) + details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated initiator index varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = append(details, 0x80) // incomplete varint + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated responder 
index varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated time varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldTime, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated cert version varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated unknown field in details", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, 50, protowire.VarintType) + details = append(details, 0x80) // incomplete varint + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert with wrong wire type rejected", func(t *testing.T) { + // fieldCert as Varint instead of Bytes. + var details []byte + details = protowire.AppendTag(details, fieldCert, protowire.VarintType) + details = protowire.AppendVarint(details, 42) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("initiator index with wrong wire type rejected", func(t *testing.T) { + // fieldInitiatorIndex as Bytes instead of Varint. 
+ var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("time with wrong wire type rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldTime, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert version with wrong wire type rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) { + // Per proto3, multiple instances of a singular field are accepted and + // the last value wins. We keep this behavior so that peers using + // alternative encoders aren't rejected. 
+ var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 1) + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 42) + got, err := UnmarshalPayload(wrapDetails(details)) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) + + t.Run("initiator index varint overflow rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, math.MaxUint32+1) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert version varint overflow rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = protowire.AppendVarint(details, math.MaxUint32+1) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + +} + +// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it +// never panics, and for any input that parses cleanly, that re-marshal + +// re-parse is a fix-point. Inputs come from an authenticated peer (post- +// noise-decrypt), so the threat model is "valid peer behaving arbitrarily," +// not "unauthenticated injection." +func FuzzPayload(f *testing.F) { + // Seed corpus with a handful of known-good shapes. 
+ f.Add(MarshalPayload(nil, Payload{})) + f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})) + f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})) + f.Add(MarshalPayload(nil, Payload{ + Cert: []byte("seed-cert"), + InitiatorIndex: 1, + ResponderIndex: 2, + Time: 3, + CertVersion: 2, + })) + f.Add([]byte{}) + f.Add([]byte{0xff}) + + f.Fuzz(func(t *testing.T, data []byte) { + p1, err := UnmarshalPayload(data) + if err != nil { + return + } + + // For any input that parses, re-marshaling and re-parsing must + // yield an equivalent Payload. This catches dispatch bugs (e.g. + // emitting a field on marshal that we don't accept on parse) and + // any non-idempotent parsing behavior. + b2 := MarshalPayload(nil, p1) + p2, err := UnmarshalPayload(b2) + if err != nil { + t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2) + } + if !payloadsEqual(p1, p2) { + t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2) + } + }) +} + +func payloadsEqual(a, b Payload) bool { + return bytes.Equal(a.Cert, b.Cert) && + a.InitiatorIndex == b.InitiatorIndex && + a.ResponderIndex == b.ResponderIndex && + a.Time == b.Time && + a.CertVersion == b.CertVersion +} diff --git a/handshake_ix.go b/handshake_ix.go deleted file mode 100644 index 4382fafc..00000000 --- a/handshake_ix.go +++ /dev/null @@ -1,746 +0,0 @@ -package nebula - -import ( - "bytes" - "net/netip" - "time" - - "github.com/flynn/noise" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" -) - -// NOISE IX Handshakes - -// This function constructs a handshake packet, but does not actually send it -// Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { - err := f.handshakeManager.allocateIndex(hh) - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). 
- WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") - return false - } - - cs := f.pki.getCertState() - v := cs.initiatingVersion - if hh.initiatingVersionOverride != cert.VersionPre1 { - v = hh.initiatingVersionOverride - } else if v < cert.Version2 { - // If we're connecting to a v6 address we should encourage use of a V2 cert - for _, a := range hh.hostinfo.vpnAddrs { - if a.Is6() { - v = cert.Version2 - break - } - } - } - - crt := cs.getCertificate(v) - if crt == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Unable to handshake with host because no certificate is available") - return false - } - - crtHs := cs.getHandshakeBytes(v) - if crtHs == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Unable to handshake with host because no certificate handshake bytes is available") - return false - } - - ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Failed to create connection state") - return false - } - hh.hostinfo.ConnectionState = ci - - hs := &NebulaHandshake{ - Details: &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: crtHs, - CertVersion: uint32(v), - }, - } - - if f.multiPort.Tx || f.multiPort.Rx { - hs.Details.InitiatorMultiPort = &MultiPortDetails{ - RxSupported: f.multiPort.Rx, - TxSupported: f.multiPort.Tx, - BasePort: uint32(f.multiPort.TxBasePort), - TotalPorts: uint32(f.multiPort.TxPorts), - } - } - - hsBytes, err := hs.Marshal() - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). 
- WithField("certVersion", v). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") - return false - } - - h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - - msg, _, _, err := ci.H.WriteMessage(h, hsBytes) - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") - return false - } - - // We are sending handshake packet 1, so we don't expect to receive - // handshake packet 1 from the responder - ci.window.Update(f.l, 1) - - hh.hostinfo.HandshakePacket[0] = msg - hh.ready = true - return true -} - -func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) { - cs := f.pki.getCertState() - crt := cs.GetDefaultCertificate() - if crt == nil { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", cs.initiatingVersion). - Error("Unable to handshake with host because no certificate is available") - return - } - - ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) - if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to create connection state") - return - } - - // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(f.l, 1) - - msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to call noise.ReadMessage") - return - } - - hs := &NebulaHandshake{} - err = hs.Unmarshal(msg) - if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Error("Failed unmarshal handshake message") - return - } - - rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) - if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") - return - } - - remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) - if err != nil { - fp, fperr := rc.Fingerprint() - if fperr != nil { - fp = "" - } - - e := f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("certVpnNetworks", rc.Networks()). - WithField("certFingerprint", fp) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) - } - - e.Info("Invalid certificate from host") - return - } - - if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") - return - } - - if remoteCert.Certificate.Version() != ci.myCert.Version() { - // We started off using the wrong certificate version, lets see if we can match the version that was sent to us - myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) - if myCertOtherVersion == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithError(err).WithFields(m{ - "from": via, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - "cert": remoteCert, - }).Debug("Might be unable to handshake with host due to missing certificate version") - } - } else { - // Record the certificate we are actually using - ci.myCert = myCertOtherVersion - } - } - - if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Info("No networks in certificate") - return - } - - certName := remoteCert.Certificate.Name() - certVersion := remoteCert.Certificate.Version() - fingerprint := remoteCert.Fingerprint - issuer := remoteCert.Certificate.Issuer() - vpnNetworks := remoteCert.Certificate.Networks() - - anyVpnAddrsInCommon := false - vpnAddrs := make([]netip.Addr, len(vpnNetworks)) - for i, network := range vpnNetworks { - if f.myVpnAddrsTable.Contains(network.Addr()) { - f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") - return - } - vpnAddrs[i] = network.Addr() - if f.myVpnNetworksTable.Contains(network.Addr()) { - anyVpnAddrsInCommon = true - } - } - - if !via.IsRelayed { - // We only want to apply the remote allow list for direct tunnels here - if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - Debug("lighthouse.remote_allow_list denied incoming handshake") - return - } - } - - myIndex, err := generateIndex(f.l) - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") - return - } - - var multiportTx, multiportRx bool - if f.multiPort.Rx || f.multiPort.Tx { - if hs.Details.InitiatorMultiPort != nil { - multiportTx = hs.Details.InitiatorMultiPort.RxSupported && f.multiPort.Tx - multiportRx = hs.Details.InitiatorMultiPort.TxSupported && f.multiPort.Rx - } - - hs.Details.ResponderMultiPort = &MultiPortDetails{ - TxSupported: f.multiPort.Tx, - RxSupported: f.multiPort.Rx, - BasePort: uint32(f.multiPort.TxBasePort), - TotalPorts: uint32(f.multiPort.TxPorts), - } - } - if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(via.UdpAddr.Port()) { - // The other side sent us a handshake from a different port, make sure - // we send responses back to the BasePort - via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort)) - } - - hostinfo := &HostInfo{ - ConnectionState: ci, - localIndexId: myIndex, - remoteIndexId: hs.Details.InitiatorIndex, - vpnAddrs: vpnAddrs, - HandshakePacket: make(map[uint8][]byte, 0), - lastHandshakeTime: hs.Details.Time, - multiportTx: multiportTx, - multiportRx: multiportRx, - relayState: RelayState{ - relays: nil, - relayForByAddr: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, - }, - } - - msgRxL := f.l.WithFields(m{ - "vpnAddrs": vpnAddrs, - "from": via, - "certName": certName, - "certVersion": certVersion, - "fingerprint": fingerprint, - "issuer": issuer, - "initiatorIndex": hs.Details.InitiatorIndex, - "responderIndex": hs.Details.ResponderIndex, - "remoteIndex": h.RemoteIndex, - "multiportTx": multiportTx, - "multiportRx": multiportRx, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - }) - - if anyVpnAddrsInCommon { - msgRxL.Info("Handshake message received") - } else { - //todo warn if not lighthouse or relay? 
- msgRxL.Info("Handshake message received, but no vpnNetworks in common.") - } - - hs.Details.ResponderIndex = myIndex - hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) - if hs.Details.Cert == nil { - msgRxL.WithField("myCertVersion", ci.myCert.Version()). - Error("Unable to handshake with host because no certificate handshake bytes is available") - return - } - - hs.Details.CertVersion = uint32(ci.myCert.Version()) - // Update the time in case their clock is way off from ours - hs.Details.Time = uint64(time.Now().UnixNano()) - - hsBytes, err := hs.Marshal() - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") - return - } - - nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) - msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") - return - } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") - return - } - - hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:])) - copy(hostinfo.HandshakePacket[0], packet[header.Len:]) - - // Regardless of whether you are the sender or receiver, you should arrive here - // and complete standing up the connection. - hostinfo.HandshakePacket[2] = make([]byte, len(msg)) - copy(hostinfo.HandshakePacket[2], msg) - - // We are sending handshake packet 2, so we don't expect to receive - // handshake packet 2 from the initiator. - ci.window.Update(f.l, 2) - - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - - hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) - if !via.IsRelayed { - hostinfo.SetRemote(via.UdpAddr) - } - hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) - - existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) - if err != nil { - switch err { - case ErrAlreadySeen: - if hostinfo.multiportRx { - // The other host is sending to us with multiport, so only grab the IP - via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port()) - } - // Update remote if preferred - if existing.SetRemoteIfPreferred(f.hostMap, via) { - // Send a test packet to ensure the other side has also switched to - // the preferred remote - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - } - - msg = existing.HandshakePacket[2] - f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if !via.IsRelayed { - err := f.outside.WriteTo(msg, via.UdpAddr) - if multiportTx { - // TODO remove alloc here - raw := make([]byte, len(msg)+udp.RawOverhead) - copy(raw[udp.RawOverhead:], msg) - err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr) - } else { - err = f.outside.WriteTo(msg, via.UdpAddr) - } - if 
err != nil { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - WithError(err).Error("Failed to send handshake message") - } else { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") - } - return - } else { - if via.relay == nil { - f.l.Error("Handshake send failed: both addr and via.relay are nil.") - return - } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") - return - } - case ErrExistingHostInfo: - // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("oldHandshakeTime", existing.lastHandshakeTime). - WithField("newHandshakeTime", hostinfo.lastHandshakeTime). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Info("Handshake too old") - - // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - return - case ErrLocalIndexCollision: - // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). - Error("Failed to add HostInfo due to localIndex collision") - return - default: - // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete - // And we forget to update it here - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Error("Failed to add HostInfo to HostMap") - return - } - } - - // Do the send - f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if !via.IsRelayed { - if multiportTx { - // TODO remove alloc here - raw := make([]byte, len(msg)+udp.RawOverhead) - copy(raw[udp.RawOverhead:], msg) - err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr) - } else { - err = f.outside.WriteTo(msg, via.UdpAddr) - } - log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) - if err != nil { - log.WithError(err).Error("Failed to send handshake") - } else { - log.Info("Handshake message sent") - } - } else { - if via.relay == nil { - f.l.Error("Handshake send failed: both addr and via.relay are nil.") - return - } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - // I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure - // it's correctly marked as working. - via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) - f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). 
- Info("Handshake message sent") - } - - f.connectionManager.AddTrafficWatch(hostinfo) - - hostinfo.remotes.RefreshFromHandshake(vpnAddrs) - - return -} - -func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { - if hh == nil { - // Nothing here to tear down, got a bogus stage 2 packet - return true - } - - hh.Lock() - defer hh.Unlock() - - hostinfo := hh.hostinfo - if !via.IsRelayed { - // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. - if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") - return false - } - } - - ci := hostinfo.ConnectionState - msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Error("Failed to call noise.ReadMessage") - - // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying - // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the - // near future - return false - } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") - - // This should be impossible in IX but just in case, if we get here then there is no chance to recover - // the handshake state machine. Tear it down - return true - } - - hs := &NebulaHandshake{} - err = hs.Unmarshal(msg) - if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). 
- WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") - - // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again - return true - } - - if (f.multiPort.Tx || f.multiPort.Rx) && hs.Details.ResponderMultiPort != nil { - hostinfo.multiportTx = hs.Details.ResponderMultiPort.RxSupported && f.multiPort.Tx - hostinfo.multiportRx = hs.Details.ResponderMultiPort.TxSupported && f.multiPort.Rx - } - - if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(via.UdpAddr.Port()) { - // The other side sent us a handshake from a different port, make sure - // we send responses back to the BasePort - via.UdpAddr = netip.AddrPortFrom( - via.UdpAddr.Addr(), - uint16(hs.Details.ResponderMultiPort.BasePort), - ) - } - - rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) - if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") - return true - } - - remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) - if err != nil { - fp, err := rc.Fingerprint() - if err != nil { - fp = "" - } - - e := f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("certFingerprint", fp). - WithField("certVpnNetworks", rc.Networks()) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) - } - - e.Info("Invalid certificate from host") - return true - } - if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). 
- WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") - return true - } - - if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("No networks in certificate") - return true - } - - vpnNetworks := remoteCert.Certificate.Networks() - certName := remoteCert.Certificate.Name() - certVersion := remoteCert.Certificate.Version() - fingerprint := remoteCert.Fingerprint - issuer := remoteCert.Certificate.Issuer() - - hostinfo.remoteIndexId = hs.Details.ResponderIndex - hostinfo.lastHandshakeTime = hs.Details.Time - - // Store their cert and our symmetric keys - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - - // Make sure the current udpAddr being used is set for responding - if !via.IsRelayed { - hostinfo.SetRemote(via.UdpAddr) - } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - } - - correctHostResponded := false - anyVpnAddrsInCommon := false - vpnAddrs := make([]netip.Addr, len(vpnNetworks)) - for i, network := range vpnNetworks { - vpnAddrs[i] = network.Addr() - if f.myVpnNetworksTable.Contains(network.Addr()) { - anyVpnAddrsInCommon = true - } - if hostinfo.vpnAddrs[0] == network.Addr() { - // todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? - correctHostResponded = true - } - } - - // Ensure the right host responded - if !correctHostResponded { - f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). 
- Info("Incorrect host responded to handshake") - - // Release our old handshake from pending, it should not continue - f.handshakeManager.DeleteHostInfo(hostinfo) - - // Create a new hostinfo/handshake for the intended vpn ip - //TODO is hostinfo.vpnAddrs[0] always the address to use? - f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { - // Block the current used address - newHH.hostinfo.remotes = hostinfo.remotes - newHH.hostinfo.remotes.BlockRemote(via) - - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). - WithField("vpnNetworks", vpnNetworks). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). - Info("Blocked addresses for handshakes") - - // Swap the packet store to benefit the original intended recipient - newHH.packetStore = hh.packetStore - hh.packetStore = []*cachedPacket{} - - // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnAddrs = vpnAddrs - f.sendCloseTunnel(hostinfo) - }) - - return true - } - - // Mark packet 2 as seen so it doesn't show up as missed - ci.window.Update(f.l, 2) - - duration := time.Since(hh.startTime).Nanoseconds() - msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("durationNs", duration). - WithField("sentCachedPackets", len(hh.packetStore)). - WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx) - if anyVpnAddrsInCommon { - msgRxL.Info("Handshake message received") - } else { - //todo warn if not lighthouse or relay? 
- msgRxL.Info("Handshake message received, but no vpnNetworks in common.") - } - - // Build up the radix for the firewall if we have subnets in the cert - hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) - - // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here - f.handshakeManager.Complete(hostinfo, f) - f.connectionManager.AddTrafficWatch(hostinfo) - - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) - } - - if len(hh.packetStore) > 0 { - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for _, cp := range hh.packetStore { - cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) - } - f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) - } - - hostinfo.remotes.RefreshFromHandshake(vpnAddrs) - f.metricHandshakes.Update(duration) - - return false -} diff --git a/handshake_manager.go b/handshake_manager.go index 1e9a0956..c39f0f27 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,14 +6,15 @@ import ( "crypto/rand" "encoding/binary" "errors" + "log/slog" "net/netip" "slices" "sync" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/handshake" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) @@ -22,7 +23,18 @@ const ( DefaultHandshakeTryInterval = time.Millisecond * 100 DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 - DefaultUseRelays = true + + // maxCachedPackets is how many unsent packets we'll buffer per pending + // handshake before dropping further ones. 
+ maxCachedPackets = 100 + + // HandshakePacket map keys mirror the IX protocol stage convention: + // stage 0 = the initiator's first message (and what the responder + // receives, stripped of header) + // stage 2 = the responder's reply + // Other handshake patterns will need new keys when added. + handshakePacketStage0 uint8 = 0 + handshakePacketStage2 uint8 = 2 ) var ( @@ -30,7 +42,6 @@ var ( tryInterval: DefaultHandshakeTryInterval, retries: DefaultHandshakeRetries, triggerBuffer: DefaultHandshakeTriggerBuffer, - useRelays: DefaultUseRelays, } ) @@ -38,7 +49,6 @@ type HandshakeConfig struct { tryInterval time.Duration retries int64 triggerBuffer int - useRelays bool messageMetrics *MessageMetrics } @@ -59,7 +69,7 @@ type HandshakeManager struct { metricInitiated metrics.Counter metricTimedOut metrics.Counter f *Interface - l *logrus.Logger + l *slog.Logger multiPort MultiPortConfig udpRaw *udp.RawConn @@ -79,34 +89,35 @@ type HandshakeHostInfo struct { packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo + machine *handshake.Machine // The handshake state machine, set during stage 0 (initiator) or beginHandshake (responder multi-message) } -func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { - if len(hh.packetStore) < 100 { +func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { + if len(hh.packetStore) < maxCachedPackets { tempPacket := make([]byte, len(packet)) copy(tempPacket, packet) hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", true). 
- Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", true, + ) } } else { m.dropped.Inc(1) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", false). - Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", false, + ) } } } -func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, @@ -140,27 +151,47 @@ func (hm *HandshakeManager) Run(ctx context.Context) { } func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { + // Gate on known handshake subtypes. Unknown subtypes (or future ones we + // don't yet support) are dropped here rather than silently routed through + // the IX path. Add a case when introducing a new pattern. 
+ switch h.Subtype { + case header.HandshakeIXPSK0: + // supported + default: + hm.l.Debug("dropping handshake with unsupported subtype", + "from", via, "subtype", h.Subtype) + return + } + // First remote allow list check before we know the vpnIp if !via.IsRelayed { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { - hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") + hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via) return } } - switch h.Subtype { - case header.HandshakeIXPSK0: - switch h.MessageCounter { - case 1: - ixHandshakeStage1(hm.f, via, packet, h) - - case 2: - newHostinfo := hm.queryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) - if tearDown && newHostinfo != nil { - hm.DeleteHostInfo(newHostinfo.hostinfo) - } + // First message of a new handshake. The wire format requires RemoteIndex + // to be zero here (the initiator has no responder index to fill in yet), + // and generateIndex never allocates 0, so any non-zero RemoteIndex on a + // stage-1 packet is malformed or someone probing for an index collision. + // Drop without paying the cost of running noise on a pending Machine. + if h.MessageCounter == 1 { + if h.RemoteIndex != 0 { + hm.l.Debug("dropping stage-1 handshake with non-zero RemoteIndex", + "from", via, "remoteIndex", h.RemoteIndex) + return } + hm.beginHandshake(via, packet, h) + return + } + + // Continuation message must match a pending handshake by index. + // Anything else is an orphaned packet (e.g., late retransmit after + // timeout) and is dropped. 
+ if hh := hm.queryIndex(h.RemoteIndex); hh != nil { + hm.continueHandshake(via, hh, packet) + return } } @@ -186,12 +217,22 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). - WithField("initiatorIndex", hh.hostinfo.localIndexId). - WithField("remoteIndex", hh.hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). - Info("Handshake timed out") + fields := []any{ + "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), + "initiatorIndex", hh.hostinfo.localIndexId, + "remoteIndex", hh.hostinfo.remoteIndexId, + "durationNs", time.Since(hh.startTime).Nanoseconds(), + } + // hh.machine can be nil here if buildStage0Packet never succeeded + // (e.g., no certificate available). In that case there's no useful + // handshake metadata to log. + if hh.machine != nil { + fields = append(fields, "handshake", m{ + "stage": uint64(hh.machine.MessageIndex()), + "style": header.SubTypeName(header.Handshake, hh.machine.Subtype()), + }) + } + hh.hostinfo.logger(hm.l).Info("Handshake timed out", fields...) hm.metricTimedOut.Inc(1) hm.DeleteHostInfo(hostinfo) return @@ -202,12 +243,25 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Check if we have a handshake packet to transmit yet if !hh.ready { - if !ixHandshakeStage0(hm.f, hh) { + if !hm.buildStage0Packet(hh) { hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) return } } + // TODO: this hardcodes "always retransmit stage 0", which is correct for + // IX (the initiator only ever sends one packet, msg1) but wrong the + // moment a 3+ message pattern lands. 
The retry loop should resend the + // most recent outgoing message, not always stage 0. That implies + // HandshakeHostInfo tracking a single "currentOutbound" packet (bytes + + // header metadata) that gets replaced as the handshake progresses, + // instead of indexing into HandshakePacket. + stage0 := hostinfo.HandshakePacket[handshakePacketStage0] + hsFields := m{ + "stage": uint64(hh.machine.MessageIndex()), + "style": header.SubTypeName(header.Handshake, hh.machine.Subtype()), + } + // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. @@ -242,13 +296,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered var sentTo []netip.AddrPort var sentMultiport bool hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { - hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + hm.messageMetrics.Tx(header.Handshake, hh.machine.Subtype(), 1) + err := hm.outside.WriteTo(stage0, addr) if err != nil { - hostinfo.logger(hm.l).WithField("udpAddr", addr). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- WithError(err).Error("Failed to send handshake message") + hostinfo.logger(hm.l).Error("Failed to send handshake message", + "udpAddr", addr, + "initiatorIndex", hostinfo.localIndexId, + "handshake", hsFields, + "error", err, + ) } else { sentTo = append(sentTo, addr) @@ -279,162 +335,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // so only log when the list of remotes has changed if remotesHaveChanged { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("multiportHandshake", sentMultiport). - Info("Handshake message sent") - } else if hm.l.Level >= logrus.DebugLevel { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Debug("Handshake message sent") + hostinfo.logger(hm.l).Info("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", hsFields, + "multiportHandshake", sentMultiport, + ) + } else if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(hm.l).Debug("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) } - if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") - // Send a RelayRequest to all known Relay IP's - for _, relay := range hostinfo.remotes.relays { - // Don't relay through the host I'm trying to connect to - if relay == vpnIp { - continue - } - - // Don't relay to myself - if hm.f.myVpnAddrsTable.Contains(relay) { - continue - } - - relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) - if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - hm.f.Handshake(relay) - continue - } - // Check the relay HostInfo to see if we already established a relay through - existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) - if !ok { - // No relays exist or requested yet. 
- if relayHostInfo.remote.IsValid() { - idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) - if err != nil { - hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") - } - - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: idx, - } - - switch relayHostInfo.GetCert().Certificate.Version() { - case cert.Version1: - if !hm.f.myVpnAddrs[0].Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") - continue - } - - if !vpnIp.Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") - continue - } - - b := hm.f.myVpnAddrs[0].As4() - m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) - b = vpnIp.As4() - m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) - case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) - m.RelayToAddr = netAddrToProtoAddr(vpnIp) - default: - hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") - continue - } - - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") - } else { - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": idx, - "relay": relay}). 
- Info("send CreateRelayRequest") - } - } - continue - } - - switch existingRelay.State { - case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) - case Disestablished: - // Mark this relay as 'requested' - relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) - fallthrough - case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") - // Re-send the CreateRelay request, in case the previous one was lost. - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: existingRelay.LocalIndex, - } - - switch relayHostInfo.GetCert().Certificate.Version() { - case cert.Version1: - if !hm.f.myVpnAddrs[0].Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") - continue - } - - if !vpnIp.Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") - continue - } - - b := hm.f.myVpnAddrs[0].As4() - m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) - b = vpnIp.As4() - m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) - case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) - m.RelayToAddr = netAddrToProtoAddr(vpnIp) - default: - hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") - continue - } - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l). - WithError(err). 
- Error("Failed to marshal Control message to create relay") - } else { - // This must send over the hostinfo, not over hm.Hosts[ip] - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": relay}). - Info("send CreateRelayRequest") - } - case PeerRequested: - // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. - fallthrough - default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relay). - Errorf("Relay unexpected state") - - } - } - } + hm.f.relayManager.StartRelays(hm.f, vpnIp, hostinfo, stage0) // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { @@ -575,9 +490,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). - Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) @@ -597,9 +513,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). 
- WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). - Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. @@ -610,16 +527,16 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId -func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { +func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) (uint32, error) { hm.mainHostMap.RLock() defer hm.mainHostMap.RUnlock() hm.Lock() defer hm.Unlock() - for i := 0; i < 32; i++ { + for range 32 { index, err := generateIndex(hm.l) if err != nil { - return err + return 0, err } _, inPending := hm.indexes[index] @@ -628,11 +545,11 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { if !inMain && !inPending { hh.hostinfo.localIndexId = index hm.indexes[index] = hh - return nil + return index, nil } } - return errors.New("failed to generate unique localIndexId") + return 0, errors.New("failed to generate unique localIndexId") } func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { @@ -655,10 +572,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { hm.indexes = map[uint32]*HandshakeHostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
- Debug("Pending hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Pending hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } } @@ -726,7 +644,7 @@ func (hm *HandshakeManager) EmitStats() { // Utility functions below -func generateIndex(l *logrus.Logger) (uint32, error) { +func generateIndex(l *slog.Logger) (uint32, error) { b := make([]byte, 4) // Let zero mean we don't know the ID, so don't generate zero @@ -734,16 +652,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) { for index == 0 { _, err := rand.Read(b) if err != nil { - l.Errorln(err) + l.Error("Failed to generate index", "error", err) return 0, err } index = binary.BigEndian.Uint32(b) } - if l.Level >= logrus.DebugLevel { - l.WithField("index", index). - Debug("Generated index") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Generated index", "index", index) } return index, nil } @@ -751,3 +668,524 @@ func generateIndex(l *logrus.Logger) (uint32, error) { func hsTimeout(tries int64, interval time.Duration) time.Duration { return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) } + +// buildStage0Packet creates the initial handshake packet for the initiator. 
+func (hm *HandshakeManager) buildStage0Packet(hh *HandshakeHostInfo) bool { + cs := hm.f.pki.getCertState() + v := cs.DefaultVersion() + if hh.initiatingVersionOverride != cert.VersionPre1 { + v = hh.initiatingVersionOverride + } else if v < cert.Version2 { + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } + } + + cred := cs.GetCredential(v) + if cred == nil { + hm.f.l.Error("Unable to handshake with host because no certificate is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, "certVersion", v) + return false + } + + machine, err := handshake.NewMachine( + v, cs.GetCredential, + hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) }, + true, header.HandshakeIXPSK0, + ) + if err != nil { + hm.f.l.Error("Failed to create handshake machine", + "vpnAddrs", hh.hostinfo.vpnAddrs, "error", err) + return false + } + + msg, err := machine.Initiate(nil) + if err != nil { + hm.f.l.Error("Failed to initiate handshake", + "vpnAddrs", hh.hostinfo.vpnAddrs, "error", err) + return false + } + + // hostinfo.ConnectionState stays nil until the handshake completes in + // continueHandshake. Pre-completion control surfaces guard with nil + // checks; the data plane never observes a pending hostinfo. + hh.hostinfo.HandshakePacket[handshakePacketStage0] = msg + hh.machine = machine + hh.ready = true + return true +} + +// beginHandshake handles an incoming handshake packet that doesn't match any +// existing pending handshake. It creates a new responder Machine and processes +// the first message. 
+func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *header.H) { + f := hm.f + cs := f.pki.getCertState() + + v := cs.DefaultVersion() + if cs.GetCredential(v) == nil { + f.l.Error("Unable to handshake with host because no certificate is available", + "from", via, "certVersion", v) + return + } + + machine, err := handshake.NewMachine( + v, cs.GetCredential, + hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) }, + false, header.HandshakeIXPSK0, + ) + if err != nil { + f.l.Error("Failed to create handshake machine", "from", via, "error", err) + return + } + + response, result, err := machine.ProcessPacket(nil, packet) + if err != nil { + f.l.Error("Failed to process handshake packet", "from", via, "error", err) + return + } + + if result == nil { + // Multi-message pattern: the responder Machine would need to be + // registered in hm.indexes so a future inbound packet finds it via + // continueHandshake. The current manager doesn't do that yet, so + // fail loudly rather than silently dropping the in-flight handshake. + // TODO: support multi-message responder flows (XX, pqIX, etc.). + // See also the IX-shaped cipher key assignment in handshake.Machine. 
+ f.l.Error("multi-message handshake responder is not supported", + "from", via, "error", handshake.ErrMultiMessageUnsupported) + return + } + + remoteCert := result.RemoteCert + if remoteCert == nil { + f.l.Error("Handshake did not produce a peer certificate", "from", via) + return + } + + // Validate peer identity + vpnAddrs, anyVpnAddrsInCommon, ok := hm.validatePeerCert(via, remoteCert) + if !ok { + return + } + + hostinfo := &HostInfo{ + ConnectionState: newConnectionStateFromResult(result), + localIndexId: result.LocalIndex, + remoteIndexId: result.RemoteIndex, + vpnAddrs: vpnAddrs, + HandshakePacket: make(map[uint8][]byte, 0), + lastHandshakeTime: result.HandshakeTime, + relayState: RelayState{ + relays: nil, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + } + + msg := "Handshake message received" + if !anyVpnAddrsInCommon { + msg = "Handshake message received, but no vpnNetworks in common." + } + f.l.Info(msg, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "initiatorIndex", result.RemoteIndex, + "responderIndex", result.LocalIndex, + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + + // packet aliases the listener's incoming buffer, so this copy must stay. + hostinfo.HandshakePacket[handshakePacketStage0] = make([]byte, len(packet[header.Len:])) + copy(hostinfo.HandshakePacket[handshakePacketStage0], packet[header.Len:]) + + // response was freshly allocated by ProcessPacket; safe to retain directly. 
+ if response != nil { + hostinfo.HandshakePacket[handshakePacketStage2] = response + } + + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) + + existing, err := hm.CheckAndComplete(hostinfo, handshakePacketStage0, f) + if err != nil { + hm.handleCheckAndCompleteError(err, existing, hostinfo, via) + return + } + + hm.sendHandshakeResponse(via, response, hostinfo, false) + f.connectionManager.AddTrafficWatch(hostinfo) + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } +} + +// continueHandshake feeds an incoming packet to an existing pending handshake Machine. +func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostInfo, packet []byte) { + f := hm.f + + hh.Lock() + defer hh.Unlock() + + // Re-verify hh is still tracked. Between queryIndex returning and us taking + // hh.Lock, handleOutbound may have timed out and deleted it. Once we hold + // hh.Lock no other deleter can race our index: handleOutbound also takes + // hh.Lock first, and handleRecvError targets a main-hostmap entry with a + // different localIndexId. 
+ hm.RLock() + cur, ok := hm.indexes[hh.hostinfo.localIndexId] + hm.RUnlock() + if !ok || cur != hh { + return + } + + hostinfo := hh.hostinfo + if !via.IsRelayed { + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + return + } + } + + machine := hh.machine + if machine == nil { + f.l.Error("No handshake machine available for continuation", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + hm.DeleteHostInfo(hostinfo) + return + } + + response, result, err := machine.ProcessPacket(nil, packet) + if err != nil { + // Recoverable errors are routine noise, log at Debug. Fatal errors get a Warn. + if machine.Failed() { + f.l.Warn("Failed to process handshake packet, abandoning", + "vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err) + hm.DeleteHostInfo(hostinfo) + } else { + f.l.Debug("Failed to process handshake packet", + "vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err) + } + return + } + + if response != nil { + hm.sendHandshakeResponse(via, response, hostinfo, false) + } + + if result == nil { + return + } + + // Handshake complete; build the ConnectionState now that we have keys and a verified peer cert. 
+ hostinfo.ConnectionState = newConnectionStateFromResult(result) + + remoteCert := result.RemoteCert + if remoteCert == nil { + f.l.Error("Handshake completed without peer certificate", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + hm.DeleteHostInfo(hostinfo) + return + } + + vpnNetworks := remoteCert.Certificate.Networks() + hostinfo.remoteIndexId = result.RemoteIndex + hostinfo.lastHandshakeTime = result.HandshakeTime + + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } else { + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + } + + // Verify correct host responded (initiator check) + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + correctHostResponded := false + anyVpnAddrsInCommon := false + for i, network := range vpnNetworks { + // inside.go drops self-routed packets at the firewall stage, but we'd + // rather not let a self-handshake complete in the first place: it + // wastes a hostmap slot, suppresses no log, and obscures routing + // misconfig. Explicit refusal here mirrors the responder-side check + // in validatePeerCert. 
+ if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + hm.DeleteHostInfo(hostinfo) + return + } + vpnAddrs[i] = network.Addr() + if hostinfo.vpnAddrs[0] == network.Addr() { + correctHostResponded = true + } + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + } + + if !correctHostResponded { + f.l.Info("Incorrect host responded to handshake", + "intendedVpnAddrs", hostinfo.vpnAddrs, + "haveVpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + + hm.DeleteHostInfo(hostinfo) + hm.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(via) + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} + hostinfo.vpnAddrs = vpnAddrs + f.sendCloseTunnel(hostinfo) + }) + return + } + + duration := time.Since(hh.startTime).Nanoseconds() + msg := "Handshake message received" + if !anyVpnAddrsInCommon { + msg = "Handshake message received, but no vpnNetworks in common." 
+ } + f.l.Info(msg, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "initiatorIndex", result.LocalIndex, + "responderIndex", result.RemoteIndex, + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + "durationNs", duration, + "sentCachedPackets", len(hh.packetStore), + ) + + hostinfo.vpnAddrs = vpnAddrs + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) + + hm.Complete(hostinfo, f) + f.connectionManager.AddTrafficWatch(hostinfo) + + if len(hh.packetStore) > 0 { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore)) + } + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for _, cp := range hh.packetStore { + cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) + } + f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) + } + + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + f.metricHandshakes.Update(duration) + + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } +} + +// validatePeerCert checks the peer certificate for self-connection and remote allow list. +// Returns the VPN addrs, whether any of them fall within one of our own VPN +// networks, and true if valid; false if rejected. +func (hm *HandshakeManager) validatePeerCert(via ViaSender, remoteCert *cert.CachedCertificate) ([]netip.Addr, bool, bool) { + f := hm.f + vpnNetworks := remoteCert.Certificate.Networks() + + // The cert package rejects host certs with no networks at parse time, so + // reaching this state would mean an invariant was bypassed elsewhere. 
+ // Refuse explicitly so downstream code (which indexes vpnAddrs[0]) can't + // panic if that invariant ever changes. + if len(vpnNetworks) == 0 { + f.l.Info("No networks in certificate", + "from", via, "cert", remoteCert) + return nil, false, false + } + + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + anyVpnAddrsInCommon := false + + for i, network := range vpnNetworks { + if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + ) + return nil, false, false + } + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + } + + if !via.IsRelayed { + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", vpnAddrs, "from", via) + return nil, false, false + } + } + + return vpnAddrs, anyVpnAddrsInCommon, true +} + +// sendHandshakeResponse sends a handshake response via the appropriate transport. +// cached is true when msg is a stored response being retransmitted because +// the peer's stage-1 retransmit landed (the ErrAlreadySeen path); false on a +// fresh response. +func (hm *HandshakeManager) sendHandshakeResponse(via ViaSender, msg []byte, hostinfo *HostInfo, cached bool) { + if msg == nil { + return + } + + f := hm.f + f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) + + // Common log fields. peerCert may be nil during intermediate + // multi-message flows (handshake hasn't completed yet); skip the cert + // block if so. 
+	logFields := []any{
+		"vpnAddrs", hostinfo.vpnAddrs,
+		"handshake", m{"stage": uint64(2), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)},
+		"cached", cached,
+		"initiatorIndex", hostinfo.remoteIndexId,
+		"responderIndex", hostinfo.localIndexId,
+	}
+	if cs := hostinfo.ConnectionState; cs != nil && cs.peerCert != nil {
+		logFields = append(logFields,
+			"certName", cs.peerCert.Certificate.Name(),
+			"certVersion", cs.peerCert.Certificate.Version(),
+			"fingerprint", cs.peerCert.Fingerprint,
+			"issuer", cs.peerCert.Certificate.Issuer(),
+		)
+	}
+
+	if !via.IsRelayed {
+		fields := append(logFields, "from", via)
+		err := f.outside.WriteTo(msg, via.UdpAddr)
+		if err != nil {
+			f.l.Error("Failed to send handshake message", append(fields, "error", err)...)
+		} else {
+			f.l.Info("Handshake message sent", fields...)
+		}
+	} else {
+		if via.relay == nil {
+			f.l.Error("Handshake send failed: both addr and via.relay are nil.")
+			return
+		}
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+		// We received a valid handshake on this relay, so make sure the relay
+		// state reflects that, in case it had been marked Disestablished.
+		via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
+		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
+		f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...)
+	}
+}
+
+// handleCheckAndCompleteError handles errors from CheckAndComplete.
+// This only fires from the responder-side beginHandshake path, after the
+// peer cert has been validated and ConnectionState populated, so peerCert
+// is always non-nil for the cases that log it.
+func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hostinfo *HostInfo, via ViaSender) { + f := hm.f + peerCert := hostinfo.ConnectionState.peerCert + hsFields := m{"stage": uint64(1), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)} + + switch err { + case ErrAlreadySeen: + if existing.SetRemoteIfPreferred(f.hostMap, via) { + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + } + // Resend the original response. The peer is committed to that response's + // ephemeral keys; a freshly-built one would have different keys and break + // the tunnel even though both sides "completed" the handshake. + if msg := existing.HandshakePacket[handshakePacketStage2]; msg != nil { + hm.sendHandshakeResponse(via, msg, existing, true) + } + + case ErrExistingHostInfo: + f.l.Info("Handshake too old", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "oldHandshakeTime", existing.lastHandshakeTime, + "newHandshakeTime", hostinfo.lastHandshakeTime, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + + case ErrLocalIndexCollision: + f.l.Error("Failed to add HostInfo due to localIndex collision", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "localIndex", hostinfo.localIndexId, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + + default: + 
f.l.Error("Failed to add HostInfo to HostMap", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "error", err, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + } +} + +// certVerifier returns a CertVerifier that validates certs against the current CA pool. +func (hm *HandshakeManager) certVerifier() handshake.CertVerifier { + return func(c cert.Certificate) (*cert.CachedCertificate, error) { + return hm.f.pki.GetCAPool().VerifyCertificate(time.Now(), c) + } +} diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 2e6d34b5..5f8383e4 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -27,7 +28,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -100,3 +101,137 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { func (mw *mockEncWriter) GetCertState() *CertState { return &CertState{initiatingVersion: cert.Version2} } + +func TestValidatePeerCert(t *testing.T) { + l := test.NewLogger() + + myNetwork := netip.MustParsePrefix("10.0.0.1/24") + myAddrTable := new(bart.Lite) + myAddrTable.Insert(netip.PrefixFrom(myNetwork.Addr(), myNetwork.Addr().BitLen())) + myNetTable := new(bart.Lite) + myNetTable.Insert(myNetwork.Masked()) + + newHM := func() *HandshakeManager { + hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig) 
+ hm.f = &Interface{ + handshakeManager: hm, + pki: &PKI{}, + l: l, + myVpnAddrsTable: myAddrTable, + myVpnNetworksTable: myNetTable, + lightHouse: hm.lightHouse, + } + return hm + } + + cached := func(networks ...netip.Prefix) *cert.CachedCertificate { + return &cert.CachedCertificate{ + Certificate: &dummyCert{name: "peer", networks: networks}, + } + } + + via := ViaSender{ + UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"), + IsRelayed: true, // skip the remote allow list (covered separately) + } + + t.Run("addr inside our networks sets anyVpnAddrsInCommon", func(t *testing.T) { + hm := newHM() + // 10.0.0.2 falls inside our 10.0.0.0/24 + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.2/24"))) + assert.True(t, ok) + assert.True(t, common) + assert.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.2")}, addrs) + }) + + t.Run("addr outside our networks leaves anyVpnAddrsInCommon false", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("192.168.1.5/24"))) + assert.True(t, ok) + assert.False(t, common) + assert.Equal(t, []netip.Addr{netip.MustParseAddr("192.168.1.5")}, addrs) + }) + + t.Run("any matching network is enough", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached( + netip.MustParsePrefix("192.168.1.5/24"), + netip.MustParsePrefix("10.0.0.42/24"), + )) + assert.True(t, ok) + assert.True(t, common) + assert.Len(t, addrs, 2) + }) + + t.Run("self-handshake is rejected", func(t *testing.T) { + hm := newHM() + // 10.0.0.1 is in myVpnAddrsTable + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.1/24"))) + assert.False(t, ok) + assert.False(t, common) + assert.Nil(t, addrs) + }) + + t.Run("cert with no networks is rejected", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached()) + assert.False(t, ok) + assert.False(t, common) + assert.Nil(t, 
addrs) + }) +} + +func TestHandleIncomingDispatch(t *testing.T) { + l := test.NewLogger() + + newHM := func() *HandshakeManager { + hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig) + hm.f = &Interface{ + handshakeManager: hm, + pki: &PKI{}, + l: l, + } + return hm + } + + via := ViaSender{ + UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"), + IsRelayed: true, // bypass remote allow list + } + + // A packet body of zero length is fine for these tests: dispatch is + // gated on header fields, and we assert that we never reach noise/cert + // processing for any of the malformed shapes here. + pkt := make([]byte, header.Len) + + t.Run("unsupported subtype dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{Type: header.Handshake, Subtype: header.MessageSubType(99), MessageCounter: 1} + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "no pending handshake should be created") + }) + + t.Run("stage-1 with non-zero RemoteIndex dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{ + Type: header.Handshake, + Subtype: header.HandshakeIXPSK0, + RemoteIndex: 0xdeadbeef, + MessageCounter: 1, + } + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "spoofed stage-1 must not create a pending machine") + }) + + t.Run("continuation with no matching pending index dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{ + Type: header.Handshake, + Subtype: header.HandshakeIXPSK0, + RemoteIndex: 0xcafef00d, + MessageCounter: 2, + } + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "orphan stage-2 must not create state") + }) +} diff --git a/header/header.go b/header/header.go index f22509b8..b973141f 100644 --- a/header/header.go +++ b/header/header.go @@ -174,6 +174,10 @@ func (h *H) SubTypeName() string { return SubTypeName(h.Type, h.Subtype) } +func (h *H) IsValidSubType() bool { + return IsValidSubType(h.Type, h.Subtype) +} + // SubTypeName will 
transform a nebula message sub type into a human string func SubTypeName(t MessageType, s MessageSubType) string { if n, ok := subTypeMap[t]; ok { @@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string { return "unknown" } +func IsValidSubType(t MessageType, s MessageSubType) bool { + if n, ok := subTypeMap[t]; ok { + if _, ok := (*n)[s]; ok { + return true + } + } + + return false +} + // NewHeader turns bytes into a header func NewHeader(b []byte) (*H, error) { h := new(H) diff --git a/hostmap.go b/hostmap.go index de863d87..4ffe319d 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,9 +1,11 @@ package nebula import ( + "context" "encoding/json" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -13,10 +15,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" ) const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address @@ -60,7 +62,7 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - l *logrus.Logger + l *slog.Logger } // For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay @@ -319,7 +321,7 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { +func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap { hm := newHostMap(l) hm.reload(c, true) @@ -327,13 +329,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { hm.reload(c, false) }) - l.WithField("preferredRanges", hm.GetPreferredRanges()). 
- Info("Main HostMap created") + l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges()) return hm } -func newHostMap(l *logrus.Logger) *HostMap { +func newHostMap(l *slog.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, @@ -352,7 +353,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { - hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") + hm.l.Warn("Failed to parse preferred ranges, ignoring", + "error", err, + "range", rawPreferredRanges, + ) continue } @@ -361,7 +365,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { oldRanges := hm.preferredRanges.Swap(&preferredRanges) if !initial { - hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") + hm.l.Info("preferred_ranges changed", + "oldPreferredRanges", *oldRanges, + "newPreferredRanges", preferredRanges, + ) } } } @@ -494,10 +501,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad hm.Indexes = map[uint32]*HostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). - Debug("Hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.Hosts), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } if isLastHostinfo { @@ -610,9 +618,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. 
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { - if f.serveDns { + if f.dnsServer != nil { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) + f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } for _, addr := range hostinfo.vpnAddrs { hm.unlockedInnerAddHostInfo(addr, hostinfo, f) @@ -621,10 +629,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). - Debug("Hostmap vpnIp added") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap vpnIp added", + "hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}, + ) } } @@ -790,18 +799,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica } } -func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { +// logger returns a derived slog.Logger with per-hostinfo fields pre-bound. +func (i *HostInfo) logger(l *slog.Logger) *slog.Logger { if i == nil { - return logrus.NewEntry(l) + return l } - li := l.WithField("vpnAddrs", i.vpnAddrs). - WithField("localIndex", i.localIndexId). 
- WithField("remoteIndex", i.remoteIndexId) + li := l.With( + "vpnAddrs", i.vpnAddrs, + "localIndex", i.localIndexId, + "remoteIndex", i.remoteIndexId, + ) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Certificate.Name()) + li = li.With("certName", peerCert.Certificate.Name()) } } @@ -810,14 +822,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { +func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) - if l.Level >= logrus.TraceLevel { - l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName", + "interfaceName", i.Name, + "allow", allow, + ) } if !allow { @@ -835,8 +850,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { } if !addr.IsValid() { - if l.Level >= logrus.DebugLevel { - l.WithField("localAddr", rawAddr).Debug("addr was invalid") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("addr was invalid", "localAddr", rawAddr) } continue } @@ -844,8 +859,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { isAllowed := allowList.Allow(addr) - if l.Level >= logrus.TraceLevel { - l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow", + "localAddr", addr, + "allowed", isAllowed, + ) } if 
!isAllowed { continue diff --git a/hostmap_test.go b/hostmap_test.go index e34a4ad0..2bd7bd43 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { func TestHostMap_reload(t *testing.T) { l := test.NewLogger() - c := config.NewC(l) + c := config.NewC(test.NewLogger()) hm := NewHostMapFromConfig(l, c) diff --git a/hostmap_tester.go b/hostmap_tester.go index fe40c533..a6ac6d44 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -1,5 +1,4 @@ //go:build e2e_testing -// +build e2e_testing package nebula diff --git a/inside.go b/inside.go index feab01c3..c556d857 100644 --- a/inside.go +++ b/inside.go @@ -1,9 +1,10 @@ package nebula import ( + "context" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" @@ -15,8 +16,11 @@ import ( func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while validating outbound packet", + "packet", packet, + "error", err, + ) } return } @@ -36,7 +40,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) if err != nil { - f.l.WithError(err).Error("Failed to forward to tun") + f.l.Error("Failed to forward to tun", "error", err) } } // Otherwise, drop. 
On linux, we should never see these packets - Linux @@ -55,10 +59,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if hostinfo == nil { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", fwPacket.RemoteAddr). - WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", + "vpnAddr", fwPacket.RemoteAddr, + "fwPacket", fwPacket, + ) } return } @@ -73,11 +78,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } else { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l). - WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping outbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping outbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } } } @@ -94,7 +99,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { _, err := f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } } @@ -109,11 +114,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * } if len(out) > iputil.MaxRejectPacketSize { - if f.l.GetLevel() >= logrus.InfoLevel { - f.l. - WithField("packet", packet). - WithField("outPacket", out). 
- Info("rejectOutside: packet too big, not sending") + if f.l.Enabled(context.Background(), slog.LevelInfo) { + f.l.Info("rejectOutside: packet too big, not sending", + "packet", packet, + "outPacket", out, + ) } return } @@ -185,10 +190,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac // This would also need to interact with unsafe_route updates through reloading the config or // use of the use_system_route_table option - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("destination", destinationAddr). - WithField("originalGateway", gatewayAddr). - Debugln("Calculated gateway for ECMP not available, attempting other gateways") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways", + "destination", destinationAddr, + "originalGateway", gatewayAddr, + ) } for i := range gateways { @@ -214,17 +220,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { - f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) + f.l.Warn("error while parsing outgoing packet for firewall check", "error", err) return } // check if packet is in outbound fw rules dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("fwPacket", fp). - WithField("reason", dropReason). - Debugln("dropping cached packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping cached packet", + "fwPacket", fp, + "reason", dropReason, + ) } return } @@ -240,9 +247,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message }) if hostInfo == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", vpnAddr). 
- Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes", + "vpnAddr", vpnAddr, + ) } return } @@ -298,12 +306,12 @@ func (f *Interface) SendVia(via *HostInfo, if noiseutil.EncryptLockNeeded { via.ConnectionState.writeLock.Unlock() } - via.logger(f.l). - WithField("outCap", cap(out)). - WithField("payloadLen", len(ad)). - WithField("headerLen", len(out)). - WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()). - Error("SendVia out buffer not large enough for relay") + via.logger(f.l).Error("SendVia out buffer not large enough for relay", + "outCap", cap(out), + "payloadLen", len(ad), + "headerLen", len(out), + "cipherOverhead", via.ConnectionState.eKey.Overhead(), + ) return } @@ -323,12 +331,12 @@ func (f *Interface) SendVia(via *HostInfo, via.ConnectionState.writeLock.Unlock() } if err != nil { - via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") + via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) return } err = f.writers[0].WriteTo(out, via.remote) if err != nil { - via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") + via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) } f.connectionManager.RelayUsed(relay.LocalIndex) } @@ -384,8 +392,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. 
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Lighthouse update triggered for punch due to rebind counter", + "vpnAddrs", hostinfo.vpnAddrs, + ) } } @@ -395,10 +405,12 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType ci.writeLock.Unlock() } if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).WithField("counter", c). - WithField("attemptedCounter", c). - Error("Failed to encrypt outgoing packet") + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", err, + "udpAddr", remote, + "counter", c, + "attemptedCounter", c, + ) return } @@ -411,8 +423,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType err = f.writers[q].WriteTo(out, remote) } if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else if hostinfo.remote.IsValid() { if multiport { @@ -423,8 +437,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType err = f.writers[q].WriteTo(out, hostinfo.remote) } if err != nil { - hostinfo.logger(f.l).WithError(err). 
- WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else { // Try to send via a relay @@ -432,7 +448,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) - hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", + "relay", relayIP, + "error", err, + ) continue } f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) diff --git a/inside_bsd.go b/inside_bsd.go index c9c7730d..dc847878 100644 --- a/inside_bsd.go +++ b/inside_bsd.go @@ -1,5 +1,4 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package nebula diff --git a/inside_generic.go b/inside_generic.go index 0bb2345a..bdcc1a6a 100644 --- a/inside_generic.go +++ b/inside_generic.go @@ -1,5 +1,4 @@ //go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd -// +build !darwin,!dragonfly,!freebsd,!netbsd,!openbsd package nebula diff --git a/interface.go b/interface.go index 6bf91f84..799ea034 100644 --- a/interface.go +++ b/interface.go @@ -5,15 +5,15 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" - "os" - "runtime" + "sync" "sync/atomic" "time" "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -30,7 +30,7 @@ type InterfaceConfig struct { pki *PKI Cipher string Firewall *Firewall - ServeDns bool + DnsServer *dnsServer HandshakeManager *HandshakeManager lightHouse *LightHouse connectionManager *connectionManager @@ -47,7 +47,7 @@ type 
InterfaceConfig struct { reQueryWait time.Duration ConntrackCacheTimeout time.Duration - l *logrus.Logger + l *slog.Logger } type Interface struct { @@ -58,7 +58,7 @@ type Interface struct { firewall *Firewall connectionManager *connectionManager handshakeManager *HandshakeManager - serveDns bool + dnsServer *dnsServer createTime time.Time lightHouse *LightHouse myBroadcastAddrsTable *bart.Lite @@ -86,17 +86,25 @@ type Interface struct { conntrackCacheTimeout time.Duration + ctx context.Context writers []udp.Conn readers []io.ReadWriteCloser - udpRaw *udp.RawConn + wg sync.WaitGroup + // fatalErr holds the first unexpected reader error that caused shutdown. + // nil means "no fatal error" (yet) + fatalErr atomic.Pointer[error] + // triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr + triggerShutdown func() + + udpRaw *udp.RawConn multiPort MultiPortConfig metricHandshakes metrics.Histogram messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics - l *logrus.Logger + l *slog.Logger } type MultiPortConfig struct { @@ -176,12 +184,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { cs := c.pki.getCertState() ifce := &Interface{ + ctx: ctx, pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, firewall: c.Firewall, - serveDns: c.ServeDns, + dnsServer: c.DnsServer, handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, @@ -222,18 +231,21 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { // activate creates the interface on the host. After the interface is created, any // other services that want to bind listeners to its IP may do so successfully. However, // the interface isn't going to process anything until run() is called. 
-func (f *Interface) activate() { +func (f *Interface) activate() error { // actually turn on tun dev addr, err := f.outside.LocalAddr() if err != nil { - f.l.WithError(err).Error("Failed to get udp listen address") + f.l.Error("Failed to get udp listen address", "error", err) } - f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). - WithField("build", f.version).WithField("udpAddr", addr). - WithField("boringcrypto", boringEnabled()). - Info("Nebula interface is active") + f.l.Info("Nebula interface is active", + "interface", f.inside.Name(), + "networks", f.myVpnNetworks, + "build", f.version, + "udpAddr", addr, + "boringcrypto", boringEnabled(), + ) if f.routines > 1 { if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() { @@ -252,33 +264,58 @@ func (f *Interface) activate() { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { - f.l.Fatal(err) + return err } } f.readers[i] = reader } - if err := f.inside.Activate(); err != nil { + f.wg.Add(1) // for us to wait on Close() to return + if err = f.inside.Activate(); err != nil { + f.wg.Done() f.inside.Close() - f.l.Fatal(err) + return err } + + return nil } -func (f *Interface) run() { +func (f *Interface) run() (func() error, error) { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { - go f.listenOut(i) + f.wg.Go(func() { + f.listenOut(i) + }) } // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { - go f.listenIn(f.readers[i], i) + f.wg.Go(func() { + f.listenIn(f.readers[i], i) + }) + } + + return func() error { + f.wg.Wait() + if e := f.fatalErr.Load(); e != nil { + return *e + } + return nil + }, nil +} + +// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one +func (f *Interface) onFatal(err error) { + swapped := f.fatalErr.CompareAndSwap(nil, &err) + if !swapped { + return + } + if f.triggerShutdown != nil { + f.triggerShutdown() } } func 
(f *Interface) listenOut(i int) { - runtime.LockOSThread() - var li udp.Conn if i > 0 { li = f.writers[i] @@ -286,42 +323,47 @@ func (f *Interface) listenOut(i int) { li = f.outside } - ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() plaintext := make([]byte, udp.MTU) h := &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) }) + + if err != nil && !f.closed.Load() { + f.l.Error("Error while reading inbound packet, closing", "error", err) + f.onFatal(err) + } + + f.l.Debug("underlay reader is done", "reader", i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - runtime.LockOSThread() - packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { - return + if !f.closed.Load() { + f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) + f.onFatal(err) } - - f.l.WithError(err).Error("Error while reading outbound packet") - // This only seems to happen when something fatal happens to the fd, so exit. 
- os.Exit(2) + break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) } + + f.l.Debug("overlay reader is done", "reader", i) } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -341,7 +383,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { if initial || c.HasChanged("pki.disconnect_invalid") { f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) if !initial { - f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) + f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load()) } } } @@ -355,7 +397,7 @@ func (f *Interface) reloadFirewall(c *config.C) { fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { - f.l.WithError(err).Error("Error while creating firewall during reload") + f.l.Error("Error while creating firewall during reload", "error", err) return } @@ -368,10 +410,11 @@ func (f *Interface) reloadFirewall(c *config.C) { // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { - f.l.WithField("firewallHashes", fw.GetRuleHashes()). - WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). - Warn("firewall rulesVersion has overflowed, resetting conntrack") + f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } else { fw.Conntrack = conntrack } @@ -379,10 +422,11 @@ func (f *Interface) reloadFirewall(c *config.C) { f.firewall = fw oldFw.Destroy() - f.l.WithField("firewallHashes", fw.GetRuleHashes()). - WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). 
- Info("New firewall has been installed") + f.l.Info("New firewall has been installed", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } func (f *Interface) reloadSendRecvError(c *config.C) { @@ -404,8 +448,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) { } } - f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()). - Info("Loaded send_recv_error config") + f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String()) } } @@ -428,8 +471,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) { } } - f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()). - Info("Loaded accept_recv_error config") + f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String()) } } @@ -505,15 +547,23 @@ func (f *Interface) GetCertState() *CertState { } func (f *Interface) Close() error { + var errs []error f.closed.Store(true) - for _, u := range f.writers { + // Release the udp readers + for i, u := range f.writers { err := u.Close() if err != nil { - f.l.WithError(err).Error("Error while closing udp socket") + f.l.Error("Error while closing udp socket", "error", err, "writer", i) + errs = append(errs, err) } } - // Release the tun device - return f.inside.Close() + // Release the tun device (closing the tun also closes all readers) + closeErr := f.inside.Close() + if closeErr != nil { + errs = append(errs, closeErr) + } + f.wg.Done() + return errors.Join(errs...) 
} diff --git a/lighthouse.go b/lighthouse.go index 1510b942..6034e68c 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -15,10 +16,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -69,18 +70,19 @@ type LightHouse struct { // Addr's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan netip.Addr + updateTrigger chan struct{} + queryChan chan netip.Addr calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter - l *logrus.Logger + l *slog.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -105,6 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, nebulaPort: nebulaPort, punchConn: pc, punchy: p, + updateTrigger: make(chan struct{}, 1), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } @@ -131,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, case *util.ContextualError: v.Log(l) case 
error: - l.WithError(err).Error("failed to reload lighthouse") + l.Error("failed to reload lighthouse", "error", err) } }) @@ -203,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used addr := addrs[0].Unmap() if lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithField("addr", rawAddr).WithField("entry", i+1). - Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") + lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range", + "addr", rawAddr, + "entry", i+1, + ) continue } @@ -222,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10))) if !initial { - lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load()) + lh.l.Info("lighthouse.interval changed", + "interval", lh.interval.Load(), + ) if lh.updateCancel != nil { // May not always have a running routine @@ -316,6 +323,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if !initial { //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic lh.l.Info("lighthouse.hosts has changed") + lh.TriggerUpdate() } } @@ -333,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { for _, v := range c.GetStringSlice("relay.relays", nil) { configRIP, err := netip.ParseAddr(v) if err != nil { - lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + lh.l.Warn("Parse relay from config failed", + "relay", v, + "error", err, + ) } else { - lh.l.WithField("relay", v).Info("Read relay from config") + lh.l.Info("Read relay from config", "relay", v) relaysForMe = append(relaysForMe, configRIP) } } @@ -360,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) { } if 
!lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). - Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not") + lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not", + "vpnAddr", addr, + "networks", lh.myVpnNetworks, + ) } out[i] = addr } @@ -432,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc } if !lh.myVpnNetworksTable.Contains(vpnAddr) { - lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). - Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work") + lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work", + "vpnAddr", vpnAddr, + "networks", lh.myVpnNetworks, + "entry", i+1, + ) } vals, ok := v.([]any) @@ -534,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { lh.Lock() rm, ok := lh.addrMap[allVpnAddrs[0]] if ok { + debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug) for _, addr := range allVpnAddrs { srm := lh.addrMap[addr] if srm == rm { delete(lh.addrMap, addr) - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", addr) + if debugEnabled { + lh.l.Debug("deleting from lighthouse", "vpnAddr", addr) } } } @@ -656,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow). 
- Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddrs", vpnAddrs, + "udpAddr", to, + "allow", allow, + ) } if !allow { return false @@ -675,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { udpAddr := protoV4AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). - Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -695,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { udpAddr := protoV6AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). 
- Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -713,21 +739,14 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { l := lh.GetLighthouses() - for i := range l { - if l[i] == vpnAddr { - return true - } - } - return false + return slices.Contains(l, vpnAddr) } func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool { l := lh.GetLighthouses() for i := range vpnAddrs { - for j := range l { - if l[j] == vpnAddrs[i] { - return true - } + if slices.Contains(l, vpnAddrs[i]) { + return true } } return false @@ -779,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { if v == cert.Version1 { if !addr.Is4() { - lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). - Error("Can't query lighthouse for v6 address using a v1 protocol") + lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol", + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } @@ -791,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v1Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). - Error("Failed to marshal lighthouse v1 query payload") + lh.l.Error("Failed to marshal lighthouse v1 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -808,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v2Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). 
- Error("Failed to marshal lighthouse v2 query payload") + lh.l.Error("Failed to marshal lighthouse v2 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -819,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { queried++ } else { - lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + lh.l.Debug("unsupported protocol version", + "op", "query", + "queryVpnAddr", addr, + "version", v, + ) continue } } @@ -848,11 +877,24 @@ func (lh *LightHouse) StartUpdateWorker() { return case <-clockSource.C: continue + case <-lh.updateTrigger: + continue } } }() } +// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking +// operation intended to be called after a handshake completes with a lighthouse, +// so the lighthouse has our current addresses without waiting for the next +// periodic update. +func (lh *LightHouse) TriggerUpdate() { + select { + case lh.updateTrigger <- struct{}{}: + default: + } +} + func (lh *LightHouse) SendUpdate() { var v4 []*V4AddrPort var v6 []*V6AddrPort @@ -898,8 +940,9 @@ func (lh *LightHouse) SendUpdate() { if v == cert.Version1 { if v1Update == nil { if !lh.myVpnNetworks[0].Addr().Is4() { - lh.l.WithField("lighthouseAddr", lhVpnAddr). - Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address", + "lighthouseAddr", lhVpnAddr, + ) continue } var relays []uint32 @@ -923,8 +966,10 @@ func (lh *LightHouse) SendUpdate() { v1Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). 
- Error("Error while marshaling for lighthouse v1 update") + lh.l.Error("Error while marshaling for lighthouse v1 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -950,8 +995,10 @@ func (lh *LightHouse) SendUpdate() { v2Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). - Error("Error while marshaling for lighthouse v2 update") + lh.l.Error("Error while marshaling for lighthouse v2 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -960,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() { updated++ } else { - lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + lh.l.Debug("unsupported protocol version", + "op", "update", + "version", v, + ) continue } } @@ -974,7 +1024,7 @@ type LightHouseHandler struct { out []byte pb []byte meta *NebulaMeta - l *logrus.Logger + l *slog.Logger } func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { @@ -1023,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). - Error("Failed to unmarshal lighthouse packet") + lhh.l.Error("Failed to unmarshal lighthouse packet", + "error", err, + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). 
- Error("Invalid lighthouse update") + lhh.l.Error("Invalid lighthouse update", + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } @@ -1058,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I don't answer queries, but received from: ", addr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I don't answer queries, but received one", "from", addr) } return } queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). - Debugln("Dropping malformed HostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Dropping malformed HostQuery", + "from", fromVpnAddrs, + "details", n.Details, + ) } return } if useVersion == cert.Version1 && queryVpnAddr.Is6() { // this case really shouldn't be possible to represent, but reject it anyway. - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). 
- Debugln("invalid vpn addr for v1 handleHostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("invalid vpn addr for v1 handleHostQuery", + "vpnAddrs", fromVpnAddrs, + "queryVpnAddr", queryVpnAddr, + ) } return } @@ -1101,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.Error("Failed to marshal lighthouse host query reply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1129,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd if ok { whereToPunch = newDest } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unable to punch to host, no addresses in common", + "to", crt.Networks(), + ) } } } @@ -1156,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") + lhh.l.Error("Failed to marshal lighthouse host was queried for", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1198,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) } } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("version", v).Debug("unsupported protocol version") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unsupported protocol version", + "op", "coalesceAnswers", + "version", v, + ) } } } @@ -1212,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ certVpnAddr, _, err := 
n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Error("dropping malformed HostQueryReply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) } return } @@ -1238,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs) } return } @@ -1262,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp //Simple check that the host sent this not someone else, if detailsVpnAddr is filled if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Host sent invalid update", + "vpnAddrs", fromVpnAddrs, + "answer", detailsVpnAddr, + ) } return } @@ -1285,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp switch useVersion { case cert.Version1: if !fromVpnAddrs[0].Is4() { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message", + "vpnAddrs", fromVpnAddrs, + ) return } vpnAddrB := fromVpnAddrs[0].As4() @@ -1293,13 +1371,16 @@ func (lhh 
*LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp case cert.Version2: // do nothing, we want to send a blank message default: - lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + lhh.l.Error("invalid protocol version", "useVersion", useVersion) return } ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") + lhh.l.Error("Failed to marshal lighthouse host update ack", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1316,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("dropping invalid HostPunchNotification", + "details", n.Details, + "error", err, + ) } return } @@ -1334,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn lhh.lh.punchConn.WriteTo(empty, vpnPeer) }() - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Punching", + "vpnPeer", vpnPeer, + "logVpnAddr", logVpnAddr, + ) } } @@ -1360,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn if lhh.lh.punchy.GetRespond() { go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Sending a nebula test packet", + "vpnAddr", detailsVpnAddr, + ) } //NOTE: we have to allocate a new output buffer here since we are spawning a 
new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine diff --git a/lighthouse_test.go b/lighthouse_test.go index fea1d1ed..c57c44ec 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -1,7 +1,6 @@ package nebula import ( - "context" "encoding/binary" "fmt" "net/netip" @@ -42,14 +41,14 @@ func Test_lhStaticMapping(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}} c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + _, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}} c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + _, err = NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -71,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) { } c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -99,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { } c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(b.Context(), l, c, cs, nil, nil) require.NoError(b, err) hAddr := netip.MustParseAddrPort("4.5.6.7:12345") @@ -202,7 +201,7 @@ func TestLighthouse_Memory(t *testing.T) { myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, 
nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) lh.ifce = &mockEncWriter{} require.NoError(t, err) lhh := lh.NewRequestHandler() @@ -288,7 +287,7 @@ func TestLighthouse_reload(t *testing.T) { myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) nc := map[string]any{ @@ -523,7 +522,7 @@ func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) { myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -589,7 +588,7 @@ func TestLighthouse_DeletesWork(t *testing.T) { myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} diff --git a/logger.go b/logger.go deleted file mode 100644 index aaf6f29c..00000000 --- a/logger.go +++ /dev/null @@ -1,45 +0,0 @@ -package nebula - -import ( - "fmt" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" -) - -func configLogger(l *logrus.Logger, c *config.C) error { - // set up our logging level - logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) - if err != nil { - return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) - } - l.SetLevel(logLevel) - - disableTimestamp := c.GetBool("logging.disable_timestamp", false) - timestampFormat := c.GetString("logging.timestamp_format", "") - fullTimestamp := (timestampFormat != "") - if timestampFormat == "" { - timestampFormat = time.RFC3339 - } - - logFormat := strings.ToLower(c.GetString("logging.format", 
"text")) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{ - TimestampFormat: timestampFormat, - FullTimestamp: fullTimestamp, - DisableTimestamp: disableTimestamp, - } - case "json": - l.Formatter = &logrus.JSONFormatter{ - TimestampFormat: timestampFormat, - DisableTimestamp: disableTimestamp, - } - default: - return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) - } - - return nil -} diff --git a/logging/logger.go b/logging/logger.go new file mode 100644 index 00000000..bbc10bb3 --- /dev/null +++ b/logging/logger.go @@ -0,0 +1,233 @@ +// Package logging wires the nebula runtime-reconfigurable slog handler used +// by nebula.Main and the nebula CLI binaries. Callers build a logger with +// NewLogger, then call ApplyConfig at startup and from a config reload +// callback to push logging.level, logging.format, and +// logging.disable_timestamp changes onto the logger without rebuilding it. +package logging + +import ( + "context" + "fmt" + "io" + "log/slog" + "strings" + "sync/atomic" + "time" +) + +// Config is the subset of *config.C that ApplyConfig reads. Declaring it +// here keeps the logging package from depending on config directly, which +// would cycle through the shared test helpers (test.NewLogger imports +// logging, and config's tests import test). *config.C satisfies this +// interface structurally with no adapter. +type Config interface { + GetString(key, def string) string + GetBool(key string, def bool) bool +} + +// LevelTrace is a custom slog level below Debug, used when logging.level is +// "trace". slog has no builtin trace level; the value is one step below +// slog.LevelDebug in slog's 4-point spacing. +const LevelTrace = slog.Level(-8) + +// NewLogger returns a *slog.Logger whose level, format, and timestamp +// emission can be reconfigured at runtime via ApplyConfig and the SSH debug +// commands. 
The default configuration is info-level text output so log +// calls made before ApplyConfig runs still produce output. Timestamps +// follow slog's default RFC3339Nano format; set logging.disable_timestamp +// in config to suppress them. +// +// ApplyConfig and the SSH commands discover the reconfig surface via +// structural type-assertion on l.Handler(), so replacement implementations +// (tests, platform-specific sinks) need only implement the subset of +// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)} +// they care about. Callers that pass a plain *slog.Logger without these +// methods get a silent no-op; reconfiguration is always opt-in. +func NewLogger(w io.Writer) *slog.Logger { + return slog.New(NewHandler(w)) +} + +// NewHandler builds the *Handler that NewLogger wraps. Exported for +// platform-specific sinks (notably cmd/nebula-service/logs_windows.go) +// that want to wrap the handler with extra behavior, such as tagging each +// record with its Event Log severity, while still benefiting from all the +// level / format / timestamp / WithAttrs machinery implemented here. +func NewHandler(w io.Writer) *Handler { + root := &handlerRoot{} + root.level.Set(slog.LevelInfo) + opts := &slog.HandlerOptions{Level: &root.level} + return &Handler{ + root: root, + text: slog.NewTextHandler(w, opts), + json: slog.NewJSONHandler(w, opts), + } +} + +// handlerRoot carries the reconfiguration state shared by every logger +// derived from a NewHandler call. All fields are consulted on the log +// path and updated lock-free. +type handlerRoot struct { + level slog.LevelVar + disableTimestamp atomic.Bool + // jsonMode picks which of the pre-derived inner handlers Handler.Handle + // dispatches to. Flipping it propagates instantly to every derived logger + // without rebuilding or chain-replaying anything. + jsonMode atomic.Bool +} + +// Handler is the slog.Handler returned by NewHandler. 
It holds two +// pre-derived slog handlers -- one text, one json -- both built from the +// same accumulated WithAttrs/WithGroup state. Handle picks which one to +// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes +// effect immediately across the whole process without having to rebuild +// any derived loggers. +type Handler struct { + root *handlerRoot + text slog.Handler + json slog.Handler +} + +func (h *Handler) Enabled(_ context.Context, l slog.Level) bool { + return h.root.level.Level() <= l +} + +func (h *Handler) Handle(ctx context.Context, r slog.Record) error { + if h.root.disableTimestamp.Load() { + r.Time = time.Time{} + } + if h.root.jsonMode.Load() { + return h.json.Handle(ctx, r) + } + return h.text.Handle(ctx, r) +} + +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithAttrs(attrs), + json: h.json.WithAttrs(attrs), + } +} + +func (h *Handler) WithGroup(name string) slog.Handler { + if name == "" { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithGroup(name), + json: h.json.WithGroup(name), + } +} + +// SetLevel updates the effective log level. Propagates to every derived +// logger via the shared LevelVar. +func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) } + +// GetLevel reports the current log level. +func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() } + +// SetFormat flips the output format atomically. Valid formats are "text" +// and "json". Every derived logger sees the new format on its next Handle +// call; no rebuild or registration is required. +func (h *Handler) SetFormat(format string) error { + switch format { + case "text": + h.root.jsonMode.Store(false) + case "json": + h.root.jsonMode.Store(true) + default: + return fmt.Errorf("unknown log format `%s`. 
possible formats: %s", format, []string{"text", "json"}) + } + return nil +} + +// GetFormat reports the currently selected format name. +func (h *Handler) GetFormat() string { + if h.root.jsonMode.Load() { + return "json" + } + return "text" +} + +// SetDisableTimestamp toggles whether Handle zeroes r.Time before +// dispatching (slog's builtin text/json handlers skip emitting the time +// attribute on a zero time). +func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) } + +// ApplyConfig reads logging.level, logging.format, and (optionally) +// logging.disable_timestamp from c and applies them to l. The reconfig +// surface is discovered via structural type-assertion on l.Handler(), so +// foreign handlers silently opt out of whichever capabilities they do not +// implement. +// +// nebula.Main does NOT call this function on your behalf; callers that want +// config-driven log level / format / timestamp updates invoke it at +// startup and register it as a reload callback themselves. This keeps the +// library from mutating an embedder's logger without their say-so. +func ApplyConfig(l *slog.Logger, c Config) error { + h := l.Handler() + + lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) + if err != nil { + return err + } + if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok { + ls.SetLevel(lvl) + } + + format := strings.ToLower(c.GetString("logging.format", "text")) + if fs, ok := h.(interface{ SetFormat(string) error }); ok { + if err := fs.SetFormat(format); err != nil { + return err + } + } + + if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok { + ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false)) + } + return nil +} + +// ParseLevel converts a config-string level name ("trace", "debug", "info", +// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. 
"fatal" and +// "panic" are accepted for backwards compatibility with pre-slog configs +// and both map to slog.LevelError. +func ParseLevel(s string) (slog.Level, error) { + switch s { + case "trace": + return LevelTrace, nil + case "debug": + return slog.LevelDebug, nil + case "info": + return slog.LevelInfo, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + case "fatal", "panic": + return slog.LevelError, nil + default: + return 0, fmt.Errorf("not a valid logging level: %q", s) + } +} + +// LevelName returns a human-readable name for a slog.Level matching the +// strings accepted by ParseLevel. +func LevelName(l slog.Level) string { + switch { + case l <= LevelTrace: + return "trace" + case l <= slog.LevelDebug: + return "debug" + case l <= slog.LevelInfo: + return "info" + case l <= slog.LevelWarn: + return "warn" + default: + return "error" + } +} diff --git a/logging/logger_bench_test.go b/logging/logger_bench_test.go new file mode 100644 index 00000000..eb29c1c3 --- /dev/null +++ b/logging/logger_bench_test.go @@ -0,0 +1,90 @@ +package logging + +import ( + "context" + "io" + "log/slog" + "testing" +) + +// BenchmarkLogger_* compare the handler returned by NewLogger against a +// stock slog text handler. The key thing we care about is the per-log +// cost on a logger that has been derived via .With(), because that is the +// shape subsystems store on their structs (HostInfo.logger(), +// lh.l.With("subsystem", ...), etc.) and call from hot paths. 
+ +func BenchmarkLogger_Stock_RootInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_RootInfo(b *testing.B) { + l := NewLogger(io.Discard) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +// Gated-off-path benchmarks: mimic the typical hot-path shape +// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below +// the active level. This is the dominant pattern in inside.go/outside.go and +// what we pay on every packet. 
+func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} + +func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} diff --git a/main.go b/main.go index 1da3c562..ce4d35d2 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,13 @@ package nebula import ( "context" "fmt" + "log/slog" "net" "net/netip" "runtime/debug" "strings" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/sshd" @@ -20,7 +20,7 @@ import ( type m = map[string]any -func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { +func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. 
defer func() { @@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg buildVersion = moduleVersion() } - l := logger - l.Formatter = &logrus.TextFormatter{ - FullTimestamp: true, - } - // Print the config if in test, the exit comes later if configTest { b, err := yaml.Marshal(c.Settings) @@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } // Print the final config - l.Println(string(b)) + l.Info(string(b)) } - err := configLogger(l, c) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) - } - - c.RegisterReloadCallback(func(c *config.C) { - err := configLogger(l, c) - if err != nil { - l.WithError(err).Error("Failed to configure the logger") - } - }) - pki, err := NewPKIFromConfig(l, c) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) @@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } - l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") + l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes()) - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) + ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) } @@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") + l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err) sshStart = nil } } @@ -99,19 +82,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg routines = 1 } 
if routines > 1 { - l.WithField("routines", routines).Info("Using multiple routines") + l.Info("Using multiple routines", "routines", routines) } } else { // deprecated and undocumented tunQueues := c.GetInt("tun.routines", 1) udpQueues := c.GetInt("listen.routines", 1) - if tunQueues > udpQueues { - routines = tunQueues - } else { - routines = udpQueues - } + routines = max(tunQueues, udpQueues) if routines != 1 { - l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead") + l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines) } } @@ -124,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg conntrackCacheTimeout = 1 * time.Second } if conntrackCacheTimeout > 0 { - l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") + l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout) } var tun overlay.Device @@ -170,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } for i := 0; i < routines; i++ { - l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port))) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) @@ -205,27 +184,19 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg messageMetrics = newMessageMetricsOnlyRecvError() } - useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false) - handshakeConfig := HandshakeConfig{ - tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), - triggerBuffer: 
c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), - useRelays: useRelays, - + tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), + retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), + triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), messageMetrics: messageMetrics, } handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - serveDns := false - if c.GetBool("lighthouse.serve_dns", false) { - if c.GetBool("lighthouse.am_lighthouse", false) { - serveDns = true - } else { - l.Warn("DNS server refusing to run because this host is not a lighthouse.") - } + ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) + if err != nil { + l.Warn("Failed to start DNS responder", "error", err) } ifConfig := &InterfaceConfig{ @@ -234,7 +205,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Outside: udpConns[0], pki: pki, Firewall: fw, - ServeDns: serveDns, + DnsServer: ds, HandshakeManager: handshakeManager, connectionManager: connManager, lightHouse: lightHouse, @@ -304,7 +275,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg go handshakeManager.Run(ctx) } - statsStart, err := startStats(l, c, buildVersion, configTest) + stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } @@ -317,23 +288,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg attachCommands(l, c, ssh, ifce) - // Start DNS server last to allow using the nebula IP as lighthouse.dns.host - var dnsStart func() - if lightHouse.amLighthouse && serveDns { - l.Debugln("Starting dns server") - dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) - } - return &Control{ - ifce, - l, - ctx, 
- cancel, - sshStart, - statsStart, - dnsStart, - lightHouse.StartUpdateWorker, - connManager.Start, + state: StateReady, + f: ifce, + l: l, + ctx: ctx, + cancel: cancel, + sshStart: sshStart, + statsStart: stats.Start, + dnsStart: ds.Start, + lighthouseStart: lightHouse.StartUpdateWorker, + connectionManagerStart: connManager.Start, }, nil } diff --git a/message_metrics.go b/message_metrics.go index 10e8472c..45de9a5c 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -13,6 +13,8 @@ type MessageMetrics struct { rxUnknown metrics.Counter txUnknown metrics.Counter + + rxInvalid metrics.Counter } func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { @@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int } } } +func (m *MessageMetrics) RxInvalid(i int64) { + if m != nil && m.rxInvalid != nil { + m.rxInvalid.Inc(i) + } +} func newMessageMetrics() *MessageMetrics { gen := func(t string) [][]metrics.Counter { @@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics { rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil), txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil), + rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil), } } diff --git a/nebula.pb.go b/nebula.pb.go index 946551b4..94a4ebe2 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{9, 0} + return fileDescriptor_2d65afa7693df5ef, []int{6, 0} } type NebulaMeta struct { @@ -489,226 +489,6 @@ func (m *NebulaPing) GetTime() uint64 { return 0 } -type NebulaHandshake struct { - Details *NebulaHandshakeDetails `protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"` - Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,proto3" json:"Hmac,omitempty"` -} - -func (m *NebulaHandshake) Reset() { *m = 
NebulaHandshake{} } -func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } -func (*NebulaHandshake) ProtoMessage() {} -func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} -} -func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *NebulaHandshake) XXX_Merge(src proto.Message) { - xxx_messageInfo_NebulaHandshake.Merge(m, src) -} -func (m *NebulaHandshake) XXX_Size() int { - return m.Size() -} -func (m *NebulaHandshake) XXX_DiscardUnknown() { - xxx_messageInfo_NebulaHandshake.DiscardUnknown(m) -} - -var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo - -func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails { - if m != nil { - return m.Details - } - return nil -} - -func (m *NebulaHandshake) GetHmac() []byte { - if m != nil { - return m.Hmac - } - return nil -} - -type MultiPortDetails struct { - RxSupported bool `protobuf:"varint,1,opt,name=RxSupported,proto3" json:"RxSupported,omitempty"` - TxSupported bool `protobuf:"varint,2,opt,name=TxSupported,proto3" json:"TxSupported,omitempty"` - BasePort uint32 `protobuf:"varint,3,opt,name=BasePort,proto3" json:"BasePort,omitempty"` - TotalPorts uint32 `protobuf:"varint,4,opt,name=TotalPorts,proto3" json:"TotalPorts,omitempty"` -} - -func (m *MultiPortDetails) Reset() { *m = MultiPortDetails{} } -func (m *MultiPortDetails) String() string { return proto.CompactTextString(m) } -func (*MultiPortDetails) ProtoMessage() {} -func (*MultiPortDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} -} -func (m *MultiPortDetails) XXX_Unmarshal(b []byte) error 
{ - return m.Unmarshal(b) -} -func (m *MultiPortDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_MultiPortDetails.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *MultiPortDetails) XXX_Merge(src proto.Message) { - xxx_messageInfo_MultiPortDetails.Merge(m, src) -} -func (m *MultiPortDetails) XXX_Size() int { - return m.Size() -} -func (m *MultiPortDetails) XXX_DiscardUnknown() { - xxx_messageInfo_MultiPortDetails.DiscardUnknown(m) -} - -var xxx_messageInfo_MultiPortDetails proto.InternalMessageInfo - -func (m *MultiPortDetails) GetRxSupported() bool { - if m != nil { - return m.RxSupported - } - return false -} - -func (m *MultiPortDetails) GetTxSupported() bool { - if m != nil { - return m.TxSupported - } - return false -} - -func (m *MultiPortDetails) GetBasePort() uint32 { - if m != nil { - return m.BasePort - } - return 0 -} - -func (m *MultiPortDetails) GetTotalPorts() uint32 { - if m != nil { - return m.TotalPorts - } - return 0 -} - -type NebulaHandshakeDetails struct { - Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"` - InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"` - ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` - Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` - Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` - CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` - InitiatorMultiPort *MultiPortDetails `protobuf:"bytes,6,opt,name=InitiatorMultiPort,proto3" json:"InitiatorMultiPort,omitempty"` - ResponderMultiPort *MultiPortDetails `protobuf:"bytes,7,opt,name=ResponderMultiPort,proto3" 
json:"ResponderMultiPort,omitempty"` -} - -func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } -func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } -func (*NebulaHandshakeDetails) ProtoMessage() {} -func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8} -} -func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) { - xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src) -} -func (m *NebulaHandshakeDetails) XXX_Size() int { - return m.Size() -} -func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() { - xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m) -} - -var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo - -func (m *NebulaHandshakeDetails) GetCert() []byte { - if m != nil { - return m.Cert - } - return nil -} - -func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 { - if m != nil { - return m.InitiatorIndex - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 { - if m != nil { - return m.ResponderIndex - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetCookie() uint64 { - if m != nil { - return m.Cookie - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetTime() uint64 { - if m != nil { - return m.Time - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { - if m != nil { - return m.CertVersion - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetInitiatorMultiPort() *MultiPortDetails { - if m != nil { - return m.InitiatorMultiPort - } - 
return nil -} - -func (m *NebulaHandshakeDetails) GetResponderMultiPort() *MultiPortDetails { - if m != nil { - return m.ResponderMultiPort - } - return nil -} - type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` @@ -723,7 +503,7 @@ func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{9} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -813,70 +593,55 @@ func init() { proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") - proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") - proto.RegisterType((*MultiPortDetails)(nil), "nebula.MultiPortDetails") - proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl") } func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 864 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x56, 0x4f, 0x6f, 0xe3, 0x44, - 0x14, 0x8f, 0x1d, 0xe7, 0x4f, 0x5f, 0x9a, 0xac, 0x79, 0x15, 0x25, 0x5d, 0x89, 0x28, 0xf8, 0x50, - 0xad, 0x38, 0x64, 0x51, 0x5b, 0x56, 0x1c, 0xd9, 0x06, 0xa1, 0xac, 0xb4, 0xed, 0x96, 0x21, 0x14, - 0x89, 0x0b, 0x9a, 0xc6, 0x43, 0x63, 0xc5, 0xf1, 0x78, 0xed, 0x31, 0x6a, 0xbe, 0x05, 0xe2, 0xb3, - 0xf0, 0x21, 0xe0, 0xb6, 0x47, 0x4e, 
0x08, 0xb5, 0x47, 0x8e, 0x7c, 0x01, 0x34, 0xe3, 0x7f, 0xe3, - 0xc4, 0x6c, 0x6f, 0xf3, 0xde, 0xef, 0xf7, 0x7b, 0xfe, 0xcd, 0x9b, 0x79, 0x93, 0xc0, 0x7e, 0xc0, - 0x6e, 0x12, 0x9f, 0x4e, 0xc2, 0x88, 0x0b, 0x8e, 0xed, 0x34, 0x72, 0xfe, 0x31, 0x01, 0x2e, 0xd5, - 0xf2, 0x82, 0x09, 0x8a, 0x27, 0x60, 0xcd, 0x37, 0x21, 0x1b, 0x1a, 0x63, 0xe3, 0xd9, 0xe0, 0x64, - 0x34, 0xc9, 0x34, 0x25, 0x63, 0x72, 0xc1, 0xe2, 0x98, 0xde, 0x32, 0xc9, 0x22, 0x8a, 0x8b, 0xa7, - 0xd0, 0xf9, 0x8a, 0x09, 0xea, 0xf9, 0xf1, 0xd0, 0x1c, 0x1b, 0xcf, 0x7a, 0x27, 0x47, 0xbb, 0xb2, - 0x8c, 0x40, 0x72, 0xa6, 0xf3, 0xaf, 0x01, 0x3d, 0xad, 0x14, 0x76, 0xc1, 0xba, 0xe4, 0x01, 0xb3, - 0x1b, 0xd8, 0x87, 0xbd, 0x19, 0x8f, 0xc5, 0x37, 0x09, 0x8b, 0x36, 0xb6, 0x81, 0x08, 0x83, 0x22, - 0x24, 0x2c, 0xf4, 0x37, 0xb6, 0x89, 0x4f, 0xe1, 0x50, 0xe6, 0xbe, 0x0b, 0x5d, 0x2a, 0xd8, 0x25, - 0x17, 0xde, 0x4f, 0xde, 0x82, 0x0a, 0x8f, 0x07, 0x76, 0x13, 0x8f, 0xe0, 0x43, 0x89, 0x5d, 0xf0, - 0x9f, 0x99, 0x5b, 0x81, 0xac, 0x1c, 0xba, 0x4a, 0x82, 0xc5, 0xb2, 0x02, 0xb5, 0x70, 0x00, 0x20, - 0xa1, 0xef, 0x97, 0x9c, 0xae, 0x3d, 0xbb, 0x8d, 0x07, 0xf0, 0xa4, 0x8c, 0xd3, 0xcf, 0x76, 0xa4, - 0xb3, 0x2b, 0x2a, 0x96, 0xd3, 0x25, 0x5b, 0xac, 0xec, 0xae, 0x74, 0x56, 0x84, 0x29, 0x65, 0x0f, - 0x3f, 0x86, 0xa3, 0x7a, 0x67, 0x2f, 0x17, 0x2b, 0x1b, 0x9c, 0x3f, 0x4c, 0xf8, 0x60, 0xa7, 0x29, - 0xe8, 0x00, 0xbc, 0xf1, 0xdd, 0xeb, 0x30, 0x78, 0xe9, 0xba, 0x91, 0x6a, 0x7d, 0xff, 0xdc, 0x1c, - 0x1a, 0x44, 0xcb, 0xe2, 0x31, 0x74, 0x72, 0x42, 0x5b, 0x35, 0x79, 0x3f, 0x6f, 0xb2, 0xcc, 0x91, - 0x1c, 0xc4, 0x09, 0xd8, 0x6f, 0x7c, 0x97, 0x30, 0x9f, 0x6e, 0xb2, 0x54, 0x3c, 0x6c, 0x8d, 0x9b, - 0x59, 0xc5, 0x1d, 0x0c, 0x4f, 0xa0, 0x5f, 0x25, 0x77, 0xc6, 0xcd, 0x9d, 0xea, 0x55, 0x0a, 0x9e, - 0x41, 0xef, 0xfa, 0x4c, 0x2e, 0xaf, 0x78, 0x24, 0xe4, 0xa1, 0x4b, 0x05, 0xe6, 0x8a, 0x12, 0x22, - 0x3a, 0x4d, 0xa9, 0x5e, 0x94, 0x2a, 0x6b, 0x4b, 0xf5, 0x42, 0x53, 0x95, 0x34, 0x1c, 0x42, 0x67, - 0xc1, 0x93, 0x40, 0xb0, 0x68, 0xd8, 0x94, 0x8d, 0x21, 0x79, 0xe8, 0x1c, 
0x83, 0xa5, 0x76, 0x3c, - 0x00, 0x73, 0xe6, 0xa9, 0xae, 0x59, 0xc4, 0x9c, 0x79, 0x32, 0x7e, 0xcd, 0xd5, 0x4d, 0xb4, 0x88, - 0xf9, 0x9a, 0x3b, 0x67, 0x00, 0xa5, 0x0d, 0xc4, 0x54, 0x95, 0x76, 0x99, 0xa4, 0x15, 0x10, 0x2c, - 0x89, 0x29, 0x4d, 0x9f, 0xa8, 0xb5, 0xf3, 0x25, 0x40, 0x69, 0xe3, 0xb1, 0x6f, 0x14, 0x15, 0x9a, - 0x5a, 0x85, 0xbb, 0x7c, 0xb0, 0xae, 0xbc, 0xe0, 0xf6, 0xfd, 0x83, 0x25, 0x19, 0x35, 0x83, 0x85, - 0x60, 0xcd, 0xbd, 0x35, 0xcb, 0xbe, 0xa3, 0xd6, 0x8e, 0xb3, 0x33, 0x36, 0x52, 0x6c, 0x37, 0x70, - 0x0f, 0x5a, 0xe9, 0x25, 0x34, 0x9c, 0x1f, 0xe1, 0x49, 0x5a, 0x77, 0x46, 0x03, 0x37, 0x5e, 0xd2, - 0x15, 0xc3, 0x2f, 0xca, 0x19, 0x35, 0xd4, 0xf5, 0xd9, 0x72, 0x50, 0x30, 0xb7, 0x07, 0x55, 0x9a, - 0x98, 0xad, 0xe9, 0x42, 0x99, 0xd8, 0x27, 0x6a, 0xed, 0xfc, 0x6a, 0x80, 0x7d, 0x91, 0xf8, 0xc2, - 0x93, 0x1b, 0xcd, 0x89, 0x63, 0xe8, 0x91, 0xbb, 0x6f, 0x93, 0x30, 0xe4, 0x91, 0x60, 0xae, 0xfa, - 0x4c, 0x97, 0xe8, 0x29, 0xc9, 0x98, 0x6b, 0x0c, 0x33, 0x65, 0x68, 0x29, 0x7c, 0x0a, 0xdd, 0x73, - 0x1a, 0x33, 0xad, 0x97, 0x45, 0x8c, 0x23, 0x80, 0x39, 0x17, 0xd4, 0xcf, 0xaf, 0x8f, 0x44, 0xb5, - 0x8c, 0xf3, 0x97, 0x09, 0x87, 0xf5, 0x9b, 0x91, 0x7b, 0x98, 0xb2, 0x48, 0x28, 0x4f, 0xfb, 0x44, - 0xad, 0xf1, 0x18, 0x06, 0xaf, 0x02, 0x4f, 0x78, 0x54, 0xf0, 0xe8, 0x55, 0xe0, 0xb2, 0xbb, 0xec, - 0xf8, 0xb7, 0xb2, 0x92, 0x47, 0x58, 0x1c, 0xf2, 0xc0, 0x65, 0x19, 0x2f, 0x35, 0xb6, 0x95, 0xc5, - 0x43, 0x68, 0x4f, 0x39, 0x5f, 0x79, 0x4c, 0x59, 0xb3, 0x48, 0x16, 0x15, 0x87, 0xd8, 0x2a, 0x0f, - 0x51, 0x36, 0x42, 0x7a, 0xb8, 0x66, 0x51, 0xec, 0xf1, 0x60, 0xd8, 0x55, 0x05, 0xf5, 0x14, 0xce, - 0x00, 0x0b, 0x1f, 0x45, 0xa7, 0xb3, 0xc9, 0x1f, 0xe6, 0x47, 0xb7, 0x7d, 0x04, 0xa4, 0x46, 0x23, - 0x2b, 0x15, 0x4e, 0xcb, 0x4a, 0x9d, 0xc7, 0x2a, 0xed, 0x6a, 0x9c, 0xdf, 0x9a, 0xd0, 0x4f, 0x1b, - 0x3c, 0xe5, 0x81, 0x88, 0xb8, 0x8f, 0x9f, 0x57, 0x2e, 0xf5, 0x27, 0xd5, 0x2b, 0x95, 0x91, 0x6a, - 0xee, 0xf5, 0x67, 0x70, 0x50, 0x18, 0x55, 0x2f, 0x8b, 0xde, 0xff, 0x3a, 0x48, 0x2a, 0x0a, 0x43, - 0x9a, 0x22, 
0x3d, 0x89, 0x3a, 0x08, 0x3f, 0x85, 0x41, 0xfe, 0xd6, 0xcd, 0xb9, 0x9a, 0x78, 0xab, - 0x78, 0x57, 0xb7, 0x10, 0xfd, 0xcd, 0xfc, 0x3a, 0xe2, 0x6b, 0xc5, 0x6e, 0x15, 0xec, 0x1d, 0x0c, - 0x27, 0xd0, 0xd3, 0x0b, 0xd7, 0xbd, 0xc7, 0x3a, 0xa1, 0x78, 0x63, 0x8b, 0xe2, 0x9d, 0x1a, 0x45, - 0x95, 0xe2, 0xcc, 0xfe, 0xef, 0xe7, 0xf1, 0x10, 0x70, 0x1a, 0x31, 0x2a, 0x98, 0xe2, 0x13, 0xf6, - 0x36, 0x61, 0xb1, 0xb0, 0x0d, 0xfc, 0x08, 0x0e, 0x2a, 0x79, 0xd9, 0x92, 0x98, 0xd9, 0xe6, 0xf9, - 0xe9, 0xef, 0xf7, 0x23, 0xe3, 0xdd, 0xfd, 0xc8, 0xf8, 0xfb, 0x7e, 0x64, 0xfc, 0xf2, 0x30, 0x6a, - 0xbc, 0x7b, 0x18, 0x35, 0xfe, 0x7c, 0x18, 0x35, 0x7e, 0x38, 0xba, 0xf5, 0xc4, 0x32, 0xb9, 0x99, - 0x2c, 0xf8, 0xfa, 0x79, 0xec, 0xd3, 0xc5, 0x6a, 0xf9, 0xf6, 0x79, 0x6a, 0xe9, 0xa6, 0xad, 0xfe, - 0x25, 0x9c, 0xfe, 0x17, 0x00, 0x00, 0xff, 0xff, 0xd6, 0x71, 0x5a, 0xf8, 0x35, 0x08, 0x00, 0x00, + // 665 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x54, 0xcd, 0x6e, 0xd3, 0x5c, + 0x10, 0x8d, 0x1d, 0x27, 0x69, 0x27, 0x4d, 0x3e, 0x7f, 0x53, 0x51, 0x12, 0x24, 0xac, 0xe0, 0x45, + 0x55, 0xb1, 0x48, 0x51, 0x5a, 0xba, 0xa6, 0x2d, 0x42, 0xa9, 0xd4, 0x9f, 0x70, 0x55, 0x8a, 0xc4, + 0xce, 0xb5, 0x2f, 0x8d, 0x55, 0xc7, 0x37, 0xb5, 0x6f, 0x50, 0xf3, 0x16, 0x3c, 0x0c, 0x0f, 0x01, + 0xbb, 0x2e, 0x59, 0xa2, 0x66, 0xc9, 0x92, 0x17, 0x40, 0xf7, 0xfa, 0xbf, 0x31, 0xb0, 0xbb, 0x33, + 0xe7, 0x9c, 0x99, 0xc9, 0xc9, 0x8c, 0x61, 0xcd, 0xa7, 0x97, 0x33, 0xcf, 0xea, 0x4f, 0x03, 0xc6, + 0x19, 0xd6, 0xa3, 0xc8, 0xfc, 0xa9, 0x02, 0x9c, 0xca, 0xe7, 0x09, 0xe5, 0x16, 0x0e, 0x40, 0x3b, + 0x9f, 0x4f, 0x69, 0x47, 0xe9, 0x29, 0x5b, 0xed, 0x81, 0xd1, 0x8f, 0x35, 0x19, 0xa3, 0x7f, 0x42, + 0xc3, 0xd0, 0xba, 0xa2, 0x82, 0x45, 0x24, 0x17, 0x77, 0xa0, 0xf1, 0x9a, 0x72, 0xcb, 0xf5, 0xc2, + 0x8e, 0xda, 0x53, 0xb6, 0x9a, 0x83, 0xee, 0xb2, 0x2c, 0x26, 0x90, 0x84, 0x69, 0xfe, 0x52, 0xa0, + 0x99, 0x2b, 0x85, 0x2b, 0xa0, 0x9d, 0x32, 0x9f, 0xea, 0x15, 0x6c, 0xc1, 0xea, 0x90, 0x85, 0xfc, + 
0xed, 0x8c, 0x06, 0x73, 0x5d, 0x41, 0x84, 0x76, 0x1a, 0x12, 0x3a, 0xf5, 0xe6, 0xba, 0x8a, 0x4f, + 0x60, 0x43, 0xe4, 0xde, 0x4d, 0x1d, 0x8b, 0xd3, 0x53, 0xc6, 0xdd, 0x8f, 0xae, 0x6d, 0x71, 0x97, + 0xf9, 0x7a, 0x15, 0xbb, 0xf0, 0x48, 0x60, 0x27, 0xec, 0x13, 0x75, 0x0a, 0x90, 0x96, 0x40, 0xa3, + 0x99, 0x6f, 0x8f, 0x0b, 0x50, 0x0d, 0xdb, 0x00, 0x02, 0x7a, 0x3f, 0x66, 0xd6, 0xc4, 0xd5, 0xeb, + 0xb8, 0x0e, 0xff, 0x65, 0x71, 0xd4, 0xb6, 0x21, 0x26, 0x1b, 0x59, 0x7c, 0x7c, 0x38, 0xa6, 0xf6, + 0xb5, 0xbe, 0x22, 0x26, 0x4b, 0xc3, 0x88, 0xb2, 0x8a, 0x4f, 0xa1, 0x5b, 0x3e, 0xd9, 0xbe, 0x7d, + 0xad, 0x83, 0xf9, 0x4d, 0x85, 0xff, 0x97, 0x4c, 0x41, 0x13, 0xe0, 0xcc, 0x73, 0x2e, 0xa6, 0xfe, + 0xbe, 0xe3, 0x04, 0xd2, 0xfa, 0xd6, 0x81, 0xda, 0x51, 0x48, 0x2e, 0x8b, 0x9b, 0xd0, 0x48, 0x08, + 0x75, 0x69, 0xf2, 0x5a, 0x62, 0xb2, 0xc8, 0x91, 0x04, 0xc4, 0x3e, 0xe8, 0x67, 0x9e, 0x43, 0xa8, + 0x67, 0xcd, 0xe3, 0x54, 0xd8, 0xa9, 0xf5, 0xaa, 0x71, 0xc5, 0x25, 0x0c, 0x07, 0xd0, 0x2a, 0x92, + 0x1b, 0xbd, 0xea, 0x52, 0xf5, 0x22, 0x05, 0x77, 0xa1, 0x79, 0xb1, 0x2b, 0x9e, 0x23, 0x16, 0x70, + 0xf1, 0xa7, 0x0b, 0x05, 0x26, 0x8a, 0x0c, 0x22, 0x79, 0x9a, 0x54, 0xed, 0x65, 0x2a, 0xed, 0x81, + 0x6a, 0x2f, 0xa7, 0xca, 0x68, 0xd8, 0x81, 0x86, 0xcd, 0x66, 0x3e, 0xa7, 0x41, 0xa7, 0x2a, 0x8c, + 0x21, 0x49, 0x68, 0x6e, 0x82, 0x26, 0x7f, 0x71, 0x1b, 0xd4, 0xa1, 0x2b, 0x5d, 0xd3, 0x88, 0x3a, + 0x74, 0x45, 0x7c, 0xcc, 0xe4, 0x26, 0x6a, 0x44, 0x3d, 0x66, 0xe6, 0x2e, 0x40, 0x36, 0x06, 0x62, + 0xa4, 0x8a, 0x5c, 0x26, 0x51, 0x05, 0x04, 0x4d, 0x60, 0x52, 0xd3, 0x22, 0xf2, 0x6d, 0xbe, 0x02, + 0xc8, 0xc6, 0xf8, 0x57, 0x8f, 0xb4, 0x42, 0x35, 0x57, 0xe1, 0x36, 0x39, 0xac, 0x91, 0xeb, 0x5f, + 0xfd, 0xfd, 0xb0, 0x04, 0xa3, 0xe4, 0xb0, 0x10, 0xb4, 0x73, 0x77, 0x42, 0xe3, 0x3e, 0xf2, 0x6d, + 0x9a, 0x4b, 0x67, 0x23, 0xc4, 0x7a, 0x05, 0x57, 0xa1, 0x16, 0x2d, 0xa1, 0x62, 0x7e, 0xa9, 0x42, + 0x2b, 0x2a, 0x7c, 0xc8, 0x7c, 0x1e, 0x30, 0x0f, 0x5f, 0x16, 0xba, 0x3f, 0x2b, 0x76, 0x8f, 0x49, + 0x25, 0x03, 0xbc, 0x80, 0xf5, 0x23, 
0xdf, 0xe5, 0xae, 0xc5, 0x59, 0x20, 0x57, 0xe0, 0xc8, 0x77, + 0xe8, 0x6d, 0xec, 0x53, 0x19, 0x24, 0x14, 0x84, 0x86, 0x53, 0xe6, 0x3b, 0x34, 0xaf, 0x88, 0x7c, + 0x29, 0x83, 0xf0, 0x39, 0xb4, 0x93, 0xa5, 0x3c, 0x67, 0xf2, 0xaf, 0xd1, 0xd2, 0x03, 0x78, 0x80, + 0xe4, 0x97, 0xfb, 0x4d, 0xc0, 0x26, 0x92, 0x5d, 0x4b, 0xd9, 0x4b, 0x18, 0xf6, 0xa1, 0x99, 0x2f, + 0x5c, 0x76, 0x38, 0x79, 0x42, 0x7a, 0x0c, 0x69, 0xf1, 0x46, 0x89, 0xa2, 0x48, 0x31, 0x87, 0x7f, + 0xfa, 0x8e, 0x6d, 0x00, 0x1e, 0x06, 0xd4, 0xe2, 0x54, 0xf2, 0x09, 0xbd, 0x99, 0xd1, 0x90, 0xeb, + 0x0a, 0x3e, 0x86, 0xf5, 0x42, 0x5e, 0x58, 0x12, 0x52, 0x5d, 0x3d, 0xd8, 0xf9, 0x7a, 0x6f, 0x28, + 0x77, 0xf7, 0x86, 0xf2, 0xe3, 0xde, 0x50, 0x3e, 0x2f, 0x8c, 0xca, 0xdd, 0xc2, 0xa8, 0x7c, 0x5f, + 0x18, 0x95, 0x0f, 0xdd, 0x2b, 0x97, 0x8f, 0x67, 0x97, 0x7d, 0x9b, 0x4d, 0xb6, 0x43, 0xcf, 0xb2, + 0xaf, 0xc7, 0x37, 0xdb, 0xd1, 0x48, 0x97, 0x75, 0xf9, 0x39, 0xdf, 0xf9, 0x1d, 0x00, 0x00, 0xff, + 0xff, 0x51, 0x0a, 0xe3, 0xd7, 0xde, 0x05, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -1161,180 +926,6 @@ func (m *NebulaPing) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } -func (m *NebulaHandshake) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NebulaHandshake) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NebulaHandshake) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if len(m.Hmac) > 0 { - i -= len(m.Hmac) - copy(dAtA[i:], m.Hmac) - i = encodeVarintNebula(dAtA, i, uint64(len(m.Hmac))) - i-- - dAtA[i] = 0x12 - } - if m.Details != nil { - { - size, err := m.Details.MarshalToSizedBuffer(dAtA[:i]) - if err != nil { - return 0, err - } - i -= size - i = encodeVarintNebula(dAtA, i, uint64(size)) - } - 
i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - -func (m *MultiPortDetails) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *MultiPortDetails) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *MultiPortDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if m.TotalPorts != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.TotalPorts)) - i-- - dAtA[i] = 0x20 - } - if m.BasePort != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.BasePort)) - i-- - dAtA[i] = 0x18 - } - if m.TxSupported { - i-- - if m.TxSupported { - dAtA[i] = 1 - } else { - dAtA[i] = 0 - } - i-- - dAtA[i] = 0x10 - } - if m.RxSupported { - i-- - if m.RxSupported { - dAtA[i] = 1 - } else { - dAtA[i] = 0 - } - i-- - dAtA[i] = 0x8 - } - return len(dAtA) - i, nil -} - -func (m *NebulaHandshakeDetails) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NebulaHandshakeDetails) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if m.CertVersion != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) - i-- - dAtA[i] = 0x40 - } - if m.ResponderMultiPort != nil { - { - size, err := m.ResponderMultiPort.MarshalToSizedBuffer(dAtA[:i]) - if err != nil { - return 0, err - } - i -= size - i = encodeVarintNebula(dAtA, i, uint64(size)) - } - i-- - dAtA[i] = 0x3a - } - if m.InitiatorMultiPort != nil { - { - size, err := m.InitiatorMultiPort.MarshalToSizedBuffer(dAtA[:i]) - if err != nil { - 
return 0, err - } - i -= size - i = encodeVarintNebula(dAtA, i, uint64(size)) - } - i-- - dAtA[i] = 0x32 - } - if m.Time != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Time)) - i-- - dAtA[i] = 0x28 - } - if m.Cookie != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Cookie)) - i-- - dAtA[i] = 0x20 - } - if m.ResponderIndex != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.ResponderIndex)) - i-- - dAtA[i] = 0x18 - } - if m.InitiatorIndex != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.InitiatorIndex)) - i-- - dAtA[i] = 0x10 - } - if len(m.Cert) > 0 { - i -= len(m.Cert) - copy(dAtA[i:], m.Cert) - i = encodeVarintNebula(dAtA, i, uint64(len(m.Cert))) - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - func (m *NebulaControl) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -1541,80 +1132,6 @@ func (m *NebulaPing) Size() (n int) { return n } -func (m *NebulaHandshake) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.Details != nil { - l = m.Details.Size() - n += 1 + l + sovNebula(uint64(l)) - } - l = len(m.Hmac) - if l > 0 { - n += 1 + l + sovNebula(uint64(l)) - } - return n -} - -func (m *MultiPortDetails) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.RxSupported { - n += 2 - } - if m.TxSupported { - n += 2 - } - if m.BasePort != 0 { - n += 1 + sovNebula(uint64(m.BasePort)) - } - if m.TotalPorts != 0 { - n += 1 + sovNebula(uint64(m.TotalPorts)) - } - return n -} - -func (m *NebulaHandshakeDetails) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - l = len(m.Cert) - if l > 0 { - n += 1 + l + sovNebula(uint64(l)) - } - if m.InitiatorIndex != 0 { - n += 1 + sovNebula(uint64(m.InitiatorIndex)) - } - if m.ResponderIndex != 0 { - n += 1 + sovNebula(uint64(m.ResponderIndex)) - } - if m.Cookie != 0 { - n += 1 + sovNebula(uint64(m.Cookie)) - } - if m.Time != 0 { - n += 1 + sovNebula(uint64(m.Time)) - } - if m.InitiatorMultiPort != nil { - l = 
m.InitiatorMultiPort.Size() - n += 1 + l + sovNebula(uint64(l)) - } - if m.ResponderMultiPort != nil { - l = m.ResponderMultiPort.Size() - n += 1 + l + sovNebula(uint64(l)) - } - if m.CertVersion != 0 { - n += 1 + sovNebula(uint64(m.CertVersion)) - } - return n -} - func (m *NebulaControl) Size() (n int) { if m == nil { return 0 @@ -2431,505 +1948,6 @@ func (m *NebulaPing) Unmarshal(dAtA []byte) error { } return nil } -func (m *NebulaHandshake) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NebulaHandshake: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NebulaHandshake: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Details", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + msglen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.Details == nil { - m.Details = &NebulaHandshakeDetails{} - } - if err := m.Details.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Hmac", wireType) - } - var byteLen int - for shift := uint(0); ; 
shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Hmac = append(m.Hmac[:0], dAtA[iNdEx:postIndex]...) - if m.Hmac == nil { - m.Hmac = []byte{} - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipNebula(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthNebula - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *MultiPortDetails) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: MultiPortDetails: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: MultiPortDetails: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RxSupported", wireType) - } - var v int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - v |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - m.RxSupported = bool(v != 0) - case 2: - if wireType != 0 { - return 
fmt.Errorf("proto: wrong wireType = %d for field TxSupported", wireType) - } - var v int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - v |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - m.TxSupported = bool(v != 0) - case 3: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field BasePort", wireType) - } - m.BasePort = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.BasePort |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 4: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field TotalPorts", wireType) - } - m.TotalPorts = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.TotalPorts |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - default: - iNdEx = preIndex - skippy, err := skipNebula(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthNebula - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NebulaHandshakeDetails: wiretype end group for 
non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NebulaHandshakeDetails: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Cert", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Cert = append(m.Cert[:0], dAtA[iNdEx:postIndex]...) - if m.Cert == nil { - m.Cert = []byte{} - } - iNdEx = postIndex - case 2: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field InitiatorIndex", wireType) - } - m.InitiatorIndex = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.InitiatorIndex |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 3: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field ResponderIndex", wireType) - } - m.ResponderIndex = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.ResponderIndex |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 4: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Cookie", wireType) - } - m.Cookie = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Cookie |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } 
- case 5: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) - } - m.Time = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Time |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 6: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field InitiatorMultiPort", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + msglen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.InitiatorMultiPort == nil { - m.InitiatorMultiPort = &MultiPortDetails{} - } - if err := m.InitiatorMultiPort.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 7: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field ResponderMultiPort", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + msglen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.ResponderMultiPort == nil { - m.ResponderMultiPort = &MultiPortDetails{} - } - if err := m.ResponderMultiPort.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 8: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field 
CertVersion", wireType) - } - m.CertVersion = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.CertVersion |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - default: - iNdEx = preIndex - skippy, err := skipNebula(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthNebula - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} func (m *NebulaControl) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 diff --git a/nebula.proto b/nebula.proto index 6123f63c..7b44f473 100644 --- a/nebula.proto +++ b/nebula.proto @@ -60,29 +60,9 @@ message NebulaPing { uint64 Time = 2; } -message NebulaHandshake { - NebulaHandshakeDetails Details = 1; - bytes Hmac = 2; -} - -message MultiPortDetails { - bool RxSupported = 1; - bool TxSupported = 2; - uint32 BasePort = 3; - uint32 TotalPorts = 4; -} - -message NebulaHandshakeDetails { - bytes Cert = 1; - uint32 InitiatorIndex = 2; - uint32 ResponderIndex = 3; - uint64 Cookie = 4; - uint64 Time = 5; - uint32 CertVersion = 8; - - MultiPortDetails InitiatorMultiPort = 6; - MultiPortDetails ResponderMultiPort = 7; -} +// NebulaHandshake / NebulaHandshakeDetails moved to +// handshake/handshake.proto. The handshake package speaks that wire format +// directly via a hand-written encoder/decoder. 
message NebulaControl { enum MessageType { diff --git a/noise.go b/noise.go index 57990a79..0491da17 100644 --- a/noise.go +++ b/noise.go @@ -15,14 +15,12 @@ type endianness interface { var noiseEndianness endianness = binary.BigEndian type NebulaCipherState struct { - c noise.Cipher - //k [32]byte - //n uint64 + c cipher.AEAD } func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { - return &NebulaCipherState{c: s.Cipher()} - + x := s.Cipher() + return &NebulaCipherState{c: x.(cipher.AEAD)} } // EncryptDanger encrypts and authenticates a given payload. @@ -46,7 +44,7 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n nb[2] = 0 nb[3] = 0 noiseEndianness.PutUint64(nb[4:], n) - out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad) + out = s.c.Seal(out, nb, plaintext, ad) //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) return out, nil } else { @@ -61,7 +59,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb[2] = 0 nb[3] = 0 noiseEndianness.PutUint64(nb[4:], n) - return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad) + return s.c.Open(out, nb, ciphertext, ad) } else { return []byte{}, nil } @@ -69,7 +67,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, func (s *NebulaCipherState) Overhead() int { if s != nil { - return s.c.(cipher.AEAD).Overhead() + return s.c.Overhead() } return 0 } diff --git a/notboring.go b/notboring.go index c86b0bc3..f138a0a6 100644 --- a/notboring.go +++ b/notboring.go @@ -1,5 +1,4 @@ //go:build !boringcrypto -// +build !boringcrypto package nebula diff --git a/outside.go b/outside.go index 0a9767ae..9a538f52 100644 --- a/outside.go +++ b/outside.go @@ -1,15 +1,16 @@ package nebula import ( + "context" "encoding/binary" "errors" + "log/slog" "net/netip" "time" "github.com/google/gopacket/layers" "golang.org/x/net/ipv6" - "github.com/sirupsen/logrus" 
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "golang.org/x/net/ipv4" @@ -19,208 +20,239 @@ const ( minFwPacketLen = 4 ) +var ErrOutOfWindow = errors.New("out of window packet") + func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + // TODO: record metrics for rx holepunch/punchy packets? if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) + } + } + return + } + + if h.Version != header.Version { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected header version received", "from", via) + } + return + } + + // Check before processing to see if this is a expected type/subtype + if !h.IsValidSubType() { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected packet received", "from", via) } return } - //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Refusing to process double encrypted packet", "from", via) } return } } + // don't keep Rx metrics for message type, since you can see those in the tun metrics + if h.Type != header.Message { + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + } + + // Unencrypted 
packets + switch h.Type { + case header.Handshake: + f.handshakeManager.HandleIncoming(via, packet, h) + return + + case header.RecvError: + f.handleRecvError(via.UdpAddr, h) + return + } + + // Relay packets are special + isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay) + var hostinfo *HostInfo - // verify if we've seen this index before, otherwise respond to the handshake initiation - if h.Type == header.Message && h.Subtype == header.MessageRelay { + if isMessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } - var ci *ConnectionState - if hostinfo != nil { - ci = hostinfo.ConnectionState + // At this point we should have a valid existing tunnel, verify and send + // recvError if necessary + if hostinfo == nil || hostinfo.ConnectionState == nil { + if !via.IsRelayed { + f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) + } + return } + // All remaining packets are encrypted + ci := hostinfo.ConnectionState + if !ci.window.Check(f.l, h.MessageCounter) { + return + } + + // Relay packets are special + if isMessageRelay { + f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) + + return + } + + out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + if err != nil { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Failed to decrypt packet", + "error", err, + "from", via, + "header", h, + ) + } + return + } + + // Roam before we respond + f.handleHostRoaming(hostinfo, via) + f.connectionManager.In(hostinfo) + switch h.Type { case header.Message: - if !f.handleEncrypted(ci, via, h) { - return - } - switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { - return - } - case header.MessageRelay: - // The entire body is sent as AD, not encrypted. 
- // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. - // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's - // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice - // which will gracefully fail in the DecryptDanger call. - signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] - signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { - return - } - // Successfully validated the thing. Get rid of the Relay header. - signedPayload = signedPayload[header.Len:] - // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, via) - // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo) - f.connectionManager.RelayUsed(h.RemoteIndex) - - relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) - if !ok { - // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") - return - } - - switch relay.Type { - case TerminalType: - // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. 
- via = ViaSender{ - UdpAddr: via.UdpAddr, - relayHI: hostinfo, - remoteIdx: relay.RemoteIndex, - relay: relay, - IsRelayed: true, - } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) - return - case ForwardingType: - // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) - if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") - return - } - - // If that relay is Established, forward the payload through it - if targetRelay.State == Established { - switch targetRelay.Type { - case ForwardingType: - // Forward this packet through the relay tunnel - // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) - return - case TerminalType: - hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") - } - } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") - return - } - } + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) + return } case header.LightHouse: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). 
- Error("Failed to decrypt lighthouse packet") - return - } - //TODO: assert via is not relayed - lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) - - // Fallthrough to the bottom to record incoming traffic + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f) case header.Test: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { + switch h.Subtype { + case header.TestReply: + // No-op, useful for the Roaming and connectionManager side-effects above + case header.TestRequest: + f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h) return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt test packet") - return - } - - if h.Subtype == header.TestRequest { - // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding - f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) - } - - // Fallthrough to the bottom to record incoming traffic - - // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they - // are unauthenticated - - case header.Handshake: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(via, packet, h) - return - - case header.RecvError: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(via.UdpAddr, h) - return - case header.CloseTunnel: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - hostinfo.logger(f.l).WithField("from", via). 
- Info("Close tunnel received, tearing down.") - + hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) f.closeTunnel(hostinfo) - return case header.Control: - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt Control packet") - return - } - - f.relayManager.HandleControlMsg(hostinfo, d, f) + f.relayManager.HandleControlMsg(hostinfo, out, f) default: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h) + } +} + +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] + var err error + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) + if err != nil { + return + } + // Successfully validated the thing. Get rid of the Relay header. 
+ signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. + f.handleHostRoaming(hostinfo, via) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet + f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } - f.handleHostRoaming(hostinfo, via) + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. + via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) + return + } - f.connectionManager.In(hostinfo) + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel + // Find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + return + default: + if 
f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type) + } + return + } + } else { + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) + return + } + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type) + } + } } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote @@ -249,20 +281,27 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port()) } if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") - return - } - - if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). - Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). 
- Info("Host roamed to new udp ip/port.") + if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote", + "suppressSeconds", RoamingSuppressSeconds, + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) + } + return + } + + hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.", + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote hostinfo.SetRemote(via.UdpAddr) @@ -270,23 +309,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { } -// handleEncrypted returns true if a packet should be processed, false otherwise -func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { - // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect - if ci == nil { - if !via.IsRelayed { - f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) - } - return false - } - // If the window check fails, refuse to process the packet, but don't send a recv error - if !ci.window.Check(f.l, h.MessageCounter) { - return false - } - - return true -} - var ( ErrPacketTooShort = errors.New("packet is too short") ErrUnknownIPVersion = errors.New("packet is an unknown ip version") @@ -336,13 +358,29 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { proto := layers.IPProtocol(data[protoAt]) switch proto { - case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: + case layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) fp.RemotePort = 0 fp.LocalPort = 0 fp.Fragment = false return nil + case layers.IPProtocolICMPv6: + if dataLen < offset+6 { + return ErrIPv6PacketTooShort + } + fp.Protocol = uint8(proto) + fp.LocalPort = 0 
//incoming vs outgoing doesn't matter for icmpv6 + icmptype := data[offset+1] + switch icmptype { + case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply: + fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier + default: + fp.RemotePort = 0 + } + fp.Fragment = false + return nil + case layers.IPProtocolTCP, layers.IPProtocolUDP: if dataLen < offset+4 { return ErrIPv6PacketTooShort @@ -432,34 +470,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Accounting for a variable header length, do we have enough data for our src/dst tuples? minLen := ihl - if !fp.Fragment && fp.Protocol != firewall.ProtoICMP { - minLen += minFwPacketLen + if !fp.Fragment { + if fp.Protocol == firewall.ProtoICMP { + minLen += minFwPacketLen + 2 + } else { + minLen += minFwPacketLen + } } + if len(data) < minLen { return ErrIPv4InvalidHeaderLength } - // Firewall packets are locally oriented - if incoming { + if incoming { // Firewall packets are locally oriented fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = 0 - fp.LocalPort = 0 - } else { - fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - } } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = 0 - fp.LocalPort = 0 - } else { - fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - } + } + + if fp.Fragment { + fp.RemotePort = 0 + fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = 0 //code would be uint16(data[ihl+1]) + } else if 
incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } return nil @@ -473,34 +515,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - hostinfo.logger(f.l).WithField("header", h). - Debugln("dropping out of window packet") - return nil, errors.New("out of window packet") + return nil, ErrOutOfWindow } return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { - var err error - - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { + err := newPacket(out, true, fwPacket) if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") - return false - } - - err = newPacket(out, true, fwPacket) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). - Warnf("Error while validating inbound packet") - return false - } - - if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). 
- Debugln("dropping out of window packet") - return false + hostinfo.logger(f.l).Warn("Error while validating inbound packet", + "error", err, + "packet", out, + ) + return } dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) @@ -508,20 +536,19 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping inbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping inbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } - return false + return } - f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } - return true } func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { @@ -535,35 +562,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) _ = f.outside.WriteTo(b, endpoint) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", index). - WithField("udpAddr", endpoint). - Debug("Recv error sent") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error sent", + "index", index, + "udpAddr", endpoint, + ) } } func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if !f.acceptRecvErrorConfig.ShouldRecvError(addr) { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). 
- Debug("Recv error received, ignoring") + f.l.Debug("Recv error received, ignoring", + "index", h.RemoteIndex, + "udpAddr", addr, + ) return } - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). - Debug("Recv error received") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error received", + "index", h.RemoteIndex, + "udpAddr", addr, + ) } hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) if hostinfo == nil { - f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") + f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex) return } if hostinfo.remote.IsValid() && hostinfo.remote != addr { - f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) + f.l.Info("Someone spoofing recv_errors?", + "addr", addr, + "hostinfoRemote", hostinfo.remote, + ) return } diff --git a/outside_test.go b/outside_test.go index 38dbef62..042ccbb3 100644 --- a/outside_test.go +++ b/outside_test.go @@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) { // next layer, missing length byte err = newPacket(buffer.Bytes()[:49], true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + err = nil // A good ICMP packet ip = layers.IPv6{ @@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) { DstIP: net.IPv6linklocalallnodes, } - icmp := layers.ICMPv6{} - - buffer.Clear() - err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp) - if err != nil { - panic(err) + icmp := layers.ICMPv6{ + TypeCode: layers.ICMPv6TypeEchoRequest, + Checksum: 0x1234, } - err = newPacket(buffer.Bytes(), true, p) - require.NoError(t, err) + buffer.Clear() + require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp)) + require.Error(t, newPacket(buffer.Bytes(), true, p)) + + buffer.Clear() + echo := layers.ICMPv6Echo{ + Identifier: 0xabcd, + SeqNumber: 1234, + } + require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, 
&icmp, &echo)) + require.NoError(t, newPacket(buffer.Bytes(), true, p)) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) - assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0xabcd), p.RemotePort) assert.Equal(t, uint16(0), p.LocalPort) assert.False(t, p.Fragment) @@ -574,7 +581,7 @@ func BenchmarkParseV6(b *testing.B) { } evilBytes := buffer.Bytes() - for i := 0; i < 200; i++ { + for range 200 { evilBytes = append(evilBytes, hopHeader...) } evilBytes = append(evilBytes, lastHopHeader...) diff --git a/test/tun.go b/overlay/overlaytest/noop.go similarity index 68% rename from test/tun.go rename to overlay/overlaytest/noop.go index fb32782f..956da7dd 100644 --- a/test/tun.go +++ b/overlay/overlaytest/noop.go @@ -1,4 +1,6 @@ -package test +// Package overlaytest provides fakes of overlay.Device for tests that do +// not want to touch a real tun device or route table. +package overlaytest import ( "errors" @@ -8,6 +10,9 @@ import ( "github.com/slackhq/nebula/routing" ) +// NoopTun is an overlay.Device that silently discards every read and write. +// Useful in tests that need to construct a nebula Interface but do not +// exercise the datapath. 
type NoopTun struct{} func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { diff --git a/overlay/route.go b/overlay/route.go index 61989581..c6403f91 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -2,6 +2,7 @@ package overlay import ( "fmt" + "log/slog" "math" "net" "net/netip" @@ -9,7 +10,6 @@ import ( "strconv" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) @@ -48,11 +48,14 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { +func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { - l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) + l.Warn("route MTU is not supported on this platform", + "goos", runtime.GOOS, + "route", r, + ) } gateways := r.Via diff --git a/overlay/route_test.go b/overlay/route_test.go index 9a959a55..f9d9dcd9 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 2) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") @@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 3) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("192.168.86.1") diff --git a/overlay/tun.go b/overlay/tun.go index 3a61d186..3af1e189 100644 --- a/overlay/tun.go 
+++ b/overlay/tun.go @@ -2,20 +2,29 @@ package overlay import ( "fmt" + "log/slog" "net" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) const DefaultMTU = 1300 -// TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type NameError struct { + Name string + Underlying error +} -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func (e *NameError) Error() string { + return fmt.Sprintf("could not set tun device name: %s because %s", e.Name, e.Underlying) +} + +// TODO: We may be able to remove routines +type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) + +func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -27,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, vpnNetworks) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..9cbb64be 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,12 +6,12 @@ package overlay import ( "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -23,10 +23,10 @@ type tun struct { vpnNetworks 
[]netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..524ef0cd 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" @@ -14,7 +15,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -30,7 +30,7 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger + l *slog.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -79,7 +79,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -153,7 +153,7 @@ func (t 
*tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error { err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { - t.l.WithField("route", r.Cidr). - Warnf("unable to add unsafe_route, identical route already exists") + t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr) } else { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { @@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error { } } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..f47880dd 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -1,13 +1,14 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "strings" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/routing" ) @@ -19,10 +20,10 @@ type disabledTun struct { // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter rx metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) 
*disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), @@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) { } t.tx.Inc(1) - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Write payload", "raw", prettyPacket(r)) } return copy(b, r), nil @@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { select { case t.read <- out: default: - t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") + t.l.Debug("tun_disabled: dropped ICMP Echo Reply response") } return true @@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) { // Check for ICMP Echo Request before spending time doing the full parsing if t.handleICMPEchoRequest(b) { - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b)) } - } else if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") + } else if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b)) } return len(b), nil } diff --git a/overlay/tun_file_linux_test.go b/overlay/tun_file_linux_test.go new file mode 100644 index 00000000..5ab87e05 --- /dev/null +++ b/overlay/tun_file_linux_test.go @@ -0,0 +1,120 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package overlay + +import ( + "errors" + "os" + "sync" + "testing" + "time" + + "golang.org/x/sys/unix" +) + +// newReadPipe returns a read fd. 
The matching write fd is registered for cleanup. +// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). +func newReadPipe(t *testing.T) int { + t.Helper() + var fds [2]int + if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { + t.Fatalf("pipe2: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fds[1]) }) + return fds[0] +} + +func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + t.Cleanup(func() { _ = tf.Close() }) + + done := make(chan error, 1) + go func() { + _, err := tf.Read(make([]byte, 64)) + done <- err + }() + + // Verify Read is actually blocked in poll. + select { + case err := <-done: + t.Fatalf("Read returned before shutdown signal: %v", err) + case <-time.After(50 * time.Millisecond): + } + + if err := tf.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + select { + case err := <-done: + if !errors.Is(err, os.ErrClosed) { + t.Fatalf("expected os.ErrClosed, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Read did not wake on shutdown") + } +} + +func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) { + parent, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + friend, err := parent.newFriend(newReadPipe(t)) + if err != nil { + _ = parent.Close() + t.Fatalf("newFriend: %v", err) + } + t.Cleanup(func() { + _ = friend.Close() + _ = parent.Close() + }) + + readers := []*tunFile{parent, friend} + errs := make([]error, len(readers)) + var wg sync.WaitGroup + for i, r := range readers { + wg.Add(1) + go func(i int, r *tunFile) { + defer wg.Done() + _, errs[i] = r.Read(make([]byte, 64)) + }(i, r) + } + + time.Sleep(50 * time.Millisecond) + + if err := parent.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case 
<-time.After(2 * time.Second): + t.Fatal("readers did not wake") + } + + for i, err := range errs { + if !errors.Is(err, os.ErrClosed) { + t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) + } + } +} + +func TestTunFile_Close_Idempotent(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("second Close should be a no-op, got %v", err) + } +} diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 8d292263..3d995553 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,15 +9,18 @@ import ( "fmt" "io" "io/fs" + "log/slog" "net/netip" + "os" "sync/atomic" "syscall" "time" "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -92,133 +95,232 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger - devFd int + l *slog.Logger + + fd int + shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls + shutdownW int // write end of the shutdown pipe; closing this signals shutdown to any blocked reader/writer + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed atomic.Bool +} + +// blockOnRead waits until the tun fd is readable or shutdown has been signaled. +// Returns os.ErrClosed if Close was called. 
+func (t *tun) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.readPoll[0].Revents + shutdownEvents := t.readPoll[1].Revents + t.readPoll[0].Revents = 0 + t.readPoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *tun) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.writePoll[0].Revents + shutdownEvents := t.writePoll[1].Revents + t.writePoll[0].Revents = 0 + t.writePoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil } func (t *tun) Read(to []byte) (int, error) { - // use readv() to read from the tunnel device, to eliminate the need for copying the buffer - if t.devFd < 0 { - return -1, syscall.EINVAL - } - // first 4 bytes is protocol family, in network byte order - head := make([]byte, 4) - - iovecs := []syscall.Iovec{ + var head [4]byte + iovecs := [2]syscall.Iovec{ {&head[0], 4}, {&to[0], uint64(len(to))}, } - - n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - - var err error - if errno != 0 { - err = syscall.Errno(errno) - } else { - err = nil - } - // fix bytes read number to exclude header - bytesRead := int(n) - if bytesRead < 0 { - return bytesRead, err - } else if bytesRead < 4 { - return 0, err - } else { - return bytesRead - 4, err + for { + n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 
2) + if errno == 0 { + bytesRead := int(n) + if bytesRead < 4 { + return 0, nil + } + return bytesRead - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnRead(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } } } // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { - // use writev() to write to the tunnel device, to eliminate the need for copying the buffer - if t.devFd < 0 { - return -1, syscall.EINVAL - } - if len(from) <= 1 { return 0, syscall.EIO } + ipVer := from[0] >> 4 - var head []byte + var head [4]byte // first 4 bytes is protocol family, in network byte order - if ipVer == 4 { - head = []byte{0, 0, 0, syscall.AF_INET} - } else if ipVer == 6 { - head = []byte{0, 0, 0, syscall.AF_INET6} - } else { + switch ipVer { + case 4: + head[3] = syscall.AF_INET + case 6: + head[3] = syscall.AF_INET6 + default: return 0, fmt.Errorf("unable to determine IP version from packet") } - iovecs := []syscall.Iovec{ + + iovecs := [2]syscall.Iovec{ {&head[0], 4}, {&from[0], uint64(len(from))}, } - - n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - - var err error - if errno != 0 { - err = syscall.Errno(errno) - } else { - err = nil + for { + n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) + if errno == 0 { + return int(n) - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnWrite(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } } - - return int(n) - 4, err } func (t *tun) Close() error { - if t.devFd >= 0 { - err := syscall.Close(t.devFd) + if t.closed.Swap(true) { + return nil + } + + // Closing the write end of the shutdown pipe causes any blocked Poll to + // return with POLLHUP on the 
shutdown fd, so readers/writers wake up and + // exit with os.ErrClosed. + if t.shutdownW >= 0 { + _ = unix.Close(t.shutdownW) + t.shutdownW = -1 + } + + if t.fd >= 0 { + if err := unix.Close(t.fd); err != nil { + t.l.Error("Error closing device", "error", err) + } + t.fd = -1 + } + + if t.shutdownR >= 0 { + _ = unix.Close(t.shutdownR) + t.shutdownR = -1 + } + + c := make(chan struct{}) + go func() { + // destroying the interface can block if a read() is still pending. Do this asynchronously. + defer close(c) + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err == nil { + defer syscall.Close(s) + ifreq := ifreqDestroy{Name: t.deviceBytes()} + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) + } if err != nil { - t.l.WithError(err).Error("Error closing device") + t.l.Error("Error destroying tunnel", "error", err) } - t.devFd = -1 + }() - c := make(chan struct{}) - go func() { - // destroying the interface can block if a read() is still pending. Do this asynchronously. 
- defer close(c) - s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) - if err == nil { - defer syscall.Close(s) - ifreq := ifreqDestroy{Name: t.deviceBytes()} - err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) - } - if err != nil { - t.l.WithError(err).Error("Error destroying tunnel") - } - }() - - // wait up to 1 second so we start blocking at the ioctl - select { - case <-c: - case <-time.After(1 * time.Second): - } + // wait up to 1 second so we start blocking at the ioctl + select { + case <-c: + case <-time.After(1 * time.Second): } return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error deviceName := c.GetString("tun.dev", "") if deviceName != "" { - fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0) + fd, err = unix.Open("/dev/"+deviceName, os.O_RDWR, 0) } if errors.Is(err, fs.ErrNotExist) || deviceName == "" { // If the device doesn't already exist, request a new one and rename it - fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0) + fd, err = unix.Open("/dev/tun", os.O_RDWR, 0) } if err != nil { return nil, err } + if err = unix.SetNonblock(fd, true); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to set tun device as nonblocking: %w", err) + } + + // Shutdown pipe lets Close wake any reader/writer blocked in Poll. 
+ var pipeFds [2]int + if err = unix.Pipe2(pipeFds[:], unix.O_CLOEXEC|unix.O_NONBLOCK); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to create shutdown pipe: %w", err) + } + shutdownR, shutdownW := pipeFds[0], pipeFds[1] + + closeOnErr := true + defer func() { + if closeOnErr { + _ = unix.Close(fd) + _ = unix.Close(shutdownR) + _ = unix.Close(shutdownW) + } + }() + // Read the name of the interface var name [16]byte arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} @@ -237,7 +339,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } if ctrlErr != nil { - return nil, err + return nil, ctrlErr } ifName := string(bytes.TrimRight(name[:], "\x00")) @@ -253,8 +355,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } defer syscall.Close(s) - fd := uintptr(s) - var fromName [16]byte var toName [16]byte copy(fromName[:], ifName) @@ -266,7 +366,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } // Set the device name - ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) + _ = ioctl(uintptr(s), syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) } t := &tun{ @@ -274,13 +374,24 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, - devFd: fd, + fd: fd, + shutdownR: shutdownR, + shutdownW: shutdownW, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownR), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownR), Events: unix.POLLIN}, + }, } err = t.reload(c, true) if err != nil { return nil, err } + closeOnErr = false c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) @@ -475,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + 
t.l.Info("Added route", "route", r) } } @@ -490,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..6bfcbdfb 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync" @@ -14,7 +15,6 @@ import ( "syscall" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -25,14 +25,14 @@ type tun struct { vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index ea666f86..c6cfb686 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,8 +4,10 @@ package overlay import ( + "encoding/binary" "fmt" "io" + "log/slog" "net" "net/netip" "os" @@ -16,7 +18,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" 
@@ -24,9 +25,175 @@ import ( "golang.org/x/sys/unix" ) +// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. +// A shared eventfd allows Close to wake all readers blocked in poll. +type tunFile struct { + fd int + shutdownFd int + lastOne bool + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed bool +} + +// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun +func (r *tunFile) newFriend(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + return &tunFile{ + fd: fd, + shutdownFd: r.shutdownFd, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + }, nil +} + +func newTunFd(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &tunFile{ + fd: fd, + shutdownFd: shutdownFd, + lastOne: true, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + } + + return out, nil +} + +func (r *tunFile) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! 
+ tunEvents := r.readPoll[0].Revents + shutdownEvents := r.readPoll[1].Revents + r.readPoll[0].Revents = 0 + r.readPoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.writePoll[0].Revents + shutdownEvents := r.writePoll[1].Revents + r.writePoll[0].Revents = 0 + r.writePoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) Read(buf []byte) (int, error) { + for { + if n, err := unix.Read(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnRead(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) Write(buf []byte) (int, error) { + for { + if n, err := unix.Write(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnWrite(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) + return err +} + +func (r *tunFile) 
Close() error { + if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem + return nil + } + r.closed = true + if r.lastOne { + _ = unix.Close(r.shutdownFd) + } + return unix.Close(r.fd) +} + type tun struct { - io.ReadWriteCloser - fd int + *tunFile + readers []*tunFile + closeLock sync.Mutex Device string vpnNetworks []netip.Prefix MaxMTU int @@ -46,7 +213,7 @@ type tun struct { routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystemLock sync.Mutex - l *logrus.Logger + l *slog.Logger } func (t *tun) Networks() []netip.Prefix { @@ -71,10 +238,8 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - - t, err := newTunGeneric(c, l, file, vpnNetworks) +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { + t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) if err != nil { return nil, err } @@ -84,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -112,14 +277,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } - copy(req.Name[:], c.GetString("tun.dev", "")) + nameStr := c.GetString("tun.dev", "") + copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { - return nil, err + _ = unix.Close(fd) + return nil, &NameError{ + Name: nameStr, + 
Underlying: err, + } } name := strings.Trim(string(req.Name[:]), "\x00") - file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, fd, vpnNetworks) if err != nil { return nil, err } @@ -129,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. +func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { + tfd, err := newTunFd(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), + tunFile: tfd, + readers: []*tunFile{tfd}, + closeLock: sync.Mutex{}, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), @@ -141,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n l: l, } - err := t.reload(c, true) - if err != nil { + if err = t.reload(c, true); err != nil { + _ = t.Close() return nil, err } @@ -202,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error { if !initial { if oldMaxMTU != newMaxMTU { t.setMTU() - t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) + t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU) } if oldDefaultMTU != newDefaultMTU { for i := range t.vpnNetworks { err := t.setDefaultRoute(t.vpnNetworks[i]) if err != nil { - t.l.Warn(err) + t.l.Warn(err.Error()) } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU) } } } @@ -235,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool { } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, 
error) { + t.closeLock.Lock() + defer t.closeLock.Unlock() + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -244,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + out, err := t.tunFile.newFriend(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } - return file, nil + t.readers = append(t.readers, out) + + return out, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -257,29 +443,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t *tun) Write(b []byte) (int, error) { - var nn int - maximum := len(b) - - for { - n, err := unix.Write(t.fd, b[nn:maximum]) - if n > 0 { - nn += n - } - if nn == len(b) { - return nn, err - } - - if err != nil { - return nn, err - } - - if n == 0 { - return nn, io.ErrUnexpectedEOF - } - } -} - func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) @@ -329,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error { } err = netlink.AddrDel(link, &al[i]) if err != nil { - t.l.WithError(err).Error("failed to remove address from tun address list") + t.l.Error("failed to remove address from tun address list", "error", err) } else { - t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + t.l.Info("removed address not listed in cert(s)", "removed", al[i].String()) } } @@ -375,12 +538,12 @@ func (t *tun) Activate() error { ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { // If we can't set the queue length nebula will still work but it may lead to packet loss - 
t.l.WithError(err).Error("Failed to set tun tx queue length") + t.l.Error("Failed to set tun tx queue length", "error", err) } const modeNone = 1 if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil { - t.l.WithError(err).Warn("Failed to disable link local address generation") + t.l.Warn("Failed to disable link local address generation", "error", err) } if err = t.addIPs(link); err != nil { @@ -419,7 +582,7 @@ func (t *tun) setMTU() { ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { // This is currently a non fatal condition because the route table must have the MTU set appropriately as well - t.l.WithError(err).Error("Failed to set tun mtu") + t.l.Error("Failed to set tun mtu", "error", err) } } @@ -442,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { } err := netlink.RouteReplace(&nr) if err != nil { - t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` for i := 0; i < 2; i++ { time.Sleep(100 * time.Millisecond) @@ -450,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { if err == nil { break } else { - t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", + "error", err, + "cidr", cidr, + "mtu", t.DefaultMTU, + ) } } if err != nil { @@ -495,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -527,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) { err := netlink.RouteDel(&nr) if err != nil { - 
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } } @@ -558,11 +725,11 @@ func (t *tun) watchRoutes() { netlinkOptions := netlink.RouteSubscribeOptions{ ReceiveBufferSize: t.useSystemRoutesBufferSize, ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0, - ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") }, + ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) }, } if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil { - t.l.WithError(err).Errorf("failed to subscribe to system route changes") + t.l.Error("failed to subscribe to system route changes", "error", err) return } @@ -604,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { link, err := netlink.LinkByName(t.Device) if err != nil { - t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") + t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device) return gateways } @@ -616,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { gateways = append(gateways, routing.NewGateway(gwAddr, 1)) } else { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } @@ -632,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) } else { // Gateway isn't in our overlay network, ignore - 
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } } @@ -667,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { gateways := t.getGatewaysFromRoute(&r.Route) if len(gateways) == 0 { // No gateways relevant to our network, no routing changes required. - t.l.WithField("route", r).Debug("Ignoring route update, no gateways") + t.l.Debug("Ignoring route update, no gateways", "route", r) return } if r.Dst == nil { - t.l.WithField("route", r).Debug("Ignoring route update, no destination address") + t.l.Debug("Ignoring route update, no destination address", "route", r) return } dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + t.l.Debug("Ignoring route update, invalid destination address", "route", r) return } @@ -689,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { t.routesFromSystemLock.Lock() if r.Type == unix.RTM_NEWROUTE { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + t.l.Info("Adding route", "destination", dst, "via", gateways) t.routesFromSystem[dst] = gateways newTree.Insert(dst, gateways) } else { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") + t.l.Info("Removing route", "destination", dst, "via", gateways) delete(t.routesFromSystem, dst) newTree.Delete(dst) } @@ -703,17 +870,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } func (t *tun) Close() error { + t.closeLock.Lock() + defer t.closeLock.Unlock() + if t.routeChan != nil { close(t.routeChan) + t.routeChan = nil } - if t.ReadWriteCloser != nil { - _ = t.ReadWriteCloser.Close() - } + // Signal all 
readers blocked in poll to wake up and exit + _ = t.tunFile.wakeForShutdown() if t.ioctlFd > 0 { - _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = unix.Close(int(t.ioctlFd)) + t.ioctlFd = 0 } - return nil + for i := range t.readers { + if i == 0 { + continue //we want to close the zeroth reader last + } + err := t.readers[i].Close() + if err != nil { + t.l.Error("error closing tun reader", "reader", i, "error", err) + } else { + t.l.Info("closed tun reader", "reader", i) + } + } + + //this is t.readers[0] too + err := t.tunFile.Close() + if err != nil { + t.l.Error("error closing tun reader", "reader", 0, "error", err) + } else { + t.l.Info("closed tun reader", "reader", 0) + } + return err } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..c971bb6e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -63,18 +63,18 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = 
unix.SetNonblock(fd, true) if err != nil { - l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := &tun{ @@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..81362184 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -54,7 +54,7 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -63,11 +63,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device 
var err error deviceName := c.GetString("tun.dev", "") @@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = unix.SetNonblock(fd, true) if err != nil { - l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := &tun{ @@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..8acd83f0 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -4,16 +4,18 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" + "github.com/slackhq/nebula/udp" ) type TestTun struct { @@ -21,14 +23,14 @@ type TestTun struct { vpnNetworks []netip.Prefix Routes []Route routeTree *bart.Table[routing.Gateways] - l *logrus.Logger + l *slog.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err @@ -49,22 +51,27 @@ func newTun(c 
*config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } -// Send will place a byte array onto the receive queue for nebula to consume +// Send will place a byte array onto the receive queue for nebula to consume. // These are unencrypted ip layer frames destined for another nebula node. -// packets should exit the udp side, capture them with udpConn.Get +// packets should exit the udp side, capture them with udpConn.Get. +// +// Send copies the input via the freelist, so the caller is free to mutate +// or reuse it after the call returns. func (t *TestTun) Send(packet []byte) { if t.closed.Load() { return } - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Tun receiving injected packet", "dataLen", len(packet)) } - t.rxPackets <- packet + buf := acquireTunBuf(len(packet)) + copy(buf, packet) + t.rxPackets <- buf } // Get will pull an unencrypted ip layer frame from the transmit queue @@ -109,12 +116,44 @@ func (t *TestTun) Write(b []byte) (n int, err error) { return 0, io.ErrClosedPipe } - packet := make([]byte, len(b), len(b)) + packet := acquireTunBuf(len(b)) copy(packet, b) t.TxPackets <- packet return len(b), nil } +// ReleaseTunBuf returns a slice from TxPackets to the harness freelist, don't use the bytes after the call. +// Channel-backed instead of sync.Pool because putting a []byte in a sync.Pool escapes the slice header to heap. +func ReleaseTunBuf(b []byte) { + if b == nil { + return + } + select { + case tunBufFreelist <- b: + default: + // Freelist full; drop the buffer for the GC. 
+ } +} + +// tunBufFreelist retains the backing arrays for TestTun.Write so steady-state allocation drops to zero once the +// freelist has saturated for the current MTU. +var tunBufFreelist = make(chan []byte, 64) + +func acquireTunBuf(n int) []byte { + var b []byte + select { + case b = <-tunBufFreelist: + default: + b = make([]byte, 0, udp.MTU) + } + if cap(b) < n { + b = make([]byte, n) + } else { + b = b[:n] + } + return b +} + func (t *TestTun) Close() error { if t.closed.CompareAndSwap(false, true) { close(t.rxPackets) @@ -128,8 +167,14 @@ func (t *TestTun) Read(b []byte) (int, error) { if !ok { return 0, os.ErrClosed } + n := len(p) copy(b, p) - return len(p), nil + // Send always pushes a freelist-acquired slice, return it once we've copied the bytes into the caller's buffer. + select { + case tunBufFreelist <- p: + default: + } + return n, nil } func (t *TestTun) SupportsMultiqueue() bool { diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index b4d78b66..680dddb3 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -7,6 +7,7 @@ import ( "crypto" "fmt" "io" + "log/slog" "net/netip" "os" "path/filepath" @@ -16,7 +17,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -33,16 +33,16 @@ type winTun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { err 
:= checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) @@ -71,10 +71,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( if err != nil { // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. // Trying a second time resolves the issue. - l.WithError(err).Debug("Failed to create wintun device, retrying") + l.Debug("Failed to create wintun device, retrying", "error", err) tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) + return nil, &NameError{ + Name: deviceName, + Underlying: fmt.Errorf("create TUN device failed: %w", err), + } } } t.tun = tunDevice.(*wintun.NativeTun) @@ -167,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } if !foundDefault4 { @@ -205,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error { // See comment on luid.AddRoute err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..e5f27f37 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,14 +2,14 @@ package overlay import ( "io" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return 
NewUserDevice(vpnNetworks) } diff --git a/pki.go b/pki.go index 19869d58..acc80486 100644 --- a/pki.go +++ b/pki.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "log/slog" "net" "net/netip" "os" @@ -13,25 +15,27 @@ import ( "sync/atomic" "time" + "github.com/flynn/noise" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/util" ) type PKI struct { cs atomic.Pointer[CertState] caPool atomic.Pointer[cert.CAPool] - l *logrus.Logger + l *slog.Logger } type CertState struct { - v1Cert cert.Certificate - v1HandshakeBytes []byte + v1Cert cert.Certificate + v1Credential *handshake.Credential - v2Cert cert.Certificate - v2HandshakeBytes []byte + v2Cert cert.Certificate + v2Credential *handshake.Credential initiatingVersion cert.Version privateKey []byte @@ -45,7 +49,7 @@ type CertState struct { myVpnBroadcastAddrsTable *bart.Lite } -func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { +func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) { pki := &PKI{l: l} err := pki.reload(c, true) if err != nil { @@ -91,13 +95,35 @@ func (p *PKI) reload(c *config.C, initial bool) error { } func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { - newState, err := newCertStateFromConfig(c) + var cipher string + var currentState *CertState + if initial { + cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: + return util.NewContextualError( + "unknown cipher", + m{"cipher": cipher}, + nil, + ) + } + } else { + // Cipher cant be hot swapped so just leave it at what it was before + currentState = p.cs.Load() + cipher = currentState.cipher + } + + newState, err := 
newCertStateFromConfig(c, cipher) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } - if !initial { - currentState := p.cs.Load() + if currentState != nil { if newState.v1Cert != nil { if currentState.v1Cert == nil { //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). @@ -157,33 +183,14 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { ) } } - - // Cipher cant be hot swapped so just leave it at what it was before - newState.cipher = currentState.cipher - - } else { - newState.cipher = c.GetString("cipher", "aes") - //TODO: this sucks and we should make it not a global - switch newState.cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return util.NewContextualError( - "unknown cipher", - m{"cipher": newState.cipher}, - nil, - ) - } } p.cs.Store(newState) if initial { - p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") + p.l.Debug("Client nebula certificate(s)", "cert", newState) } else { - p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") + p.l.Info("Client certificate(s) refreshed from disk", "cert", newState) } return nil } @@ -195,7 +202,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { } p.caPool.Store(caPool) - p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints()) return nil } @@ -207,6 +214,20 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate { return c } +// DefaultVersion returns the preferred cert version for initiating handshakes. +func (cs *CertState) DefaultVersion() cert.Version { return cs.initiatingVersion } + +// GetCredential returns the pre-computed handshake credential for the given version, or nil. 
+func (cs *CertState) GetCredential(v cert.Version) *handshake.Credential { + switch v { + case cert.Version1: + return cs.v1Credential + case cert.Version2: + return cs.v2Credential + } + return nil +} + func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { switch v { case cert.Version1: @@ -218,17 +239,25 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { return nil } -// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. -// Callers must check if the return []byte is nil. -func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { - switch v { - case cert.Version1: - return cs.v1HandshakeBytes - case cert.Version2: - return cs.v2HandshakeBytes +func newCipherSuite(curve cert.Curve, pkcs11backed bool, cipher string) (noise.CipherSuite, error) { + var dhFunc noise.DHFunc + switch curve { + case cert.Curve_CURVE25519: + dhFunc = noise.DH25519 + case cert.Curve_P256: + if pkcs11backed { + dhFunc = noiseutil.DHP256PKCS11 + } else { + dhFunc = noiseutil.DHP256 + } default: - return nil + return nil, fmt.Errorf("unsupported curve: %s", curve) } + + if cipher == "chachapoly" { + return noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256), nil + } + return noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256), nil } func (cs *CertState) String() string { @@ -260,7 +289,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) { return json.Marshal(msg) } -func newCertStateFromConfig(c *config.C) (*CertState, error) { +func newCertStateFromConfig(c *config.C, cipher string) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") @@ -344,13 +373,14 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } - return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) + return newCertState(initiatingVersion, v1, v2, 
isPkcs11, curve, rawKey, cipher) } -func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte, cipher string) (*CertState, error) { cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, + cipher: cipher, myVpnNetworksTable: new(bart.Lite), myVpnAddrsTable: new(bart.Lite), myVpnBroadcastAddrsTable: new(bart.Lite), @@ -383,10 +413,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p v1hs, err := v1.MarshalForHandshakes() if err != nil { - return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + return nil, fmt.Errorf("error marshalling v1 certificate for handshake: %w", err) + } + ncs, err := newCipherSuite(v1.Curve(), pkcs11backed, cipher) + if err != nil { + return nil, err } cs.v1Cert = v1 - cs.v1HandshakeBytes = v1hs + cs.v1Credential = handshake.NewCredential(v1, v1hs, privateKey, ncs) if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version1 @@ -404,10 +438,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p v2hs, err := v2.MarshalForHandshakes() if err != nil { - return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + return nil, fmt.Errorf("error marshalling v2 certificate for handshake: %w", err) + } + ncs, err := newCipherSuite(v2.Curve(), pkcs11backed, cipher) + if err != nil { + return nil, err } cs.v2Cert = v2 - cs.v2HandshakeBytes = v2hs + cs.v2Credential = handshake.NewCredential(v2, v2hs, privateKey, ncs) if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version2 @@ -486,32 +524,32 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) { return c, b, nil } -func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { - var rawCA []byte - var err error - 
+func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) { caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") } - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) + var caReader io.ReadCloser + var err error + if strings.Contains(caPathOrPEM, "-----BEGIN") { + caReader = io.NopCloser(strings.NewReader(caPathOrPEM)) } else { - rawCA, err = os.ReadFile(caPathOrPEM) + caReader, err = os.Open(caPathOrPEM) if err != nil { return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) } } + defer caReader.Close() - caPool, err := cert.NewCAPoolFromPEM(rawCA) + caPool, err := cert.NewCAPoolFromPEMReader(caReader) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { if crt.Certificate.Expired(time.Now()) { expired++ - l.WithField("cert", crt).Warn("expired certificate present in CA pool") + l.Warn("expired certificate present in CA pool", "cert", crt) } } @@ -529,7 +567,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { caPool.BlocklistFingerprint(fp) } - l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates") + l.Info("Blocklisted certificates", "fingerprintCount", len(bl)) } return caPool, nil diff --git a/pki_hup_benchmark_test.go b/pki_hup_benchmark_test.go new file mode 100644 index 00000000..bca23d78 --- /dev/null +++ b/pki_hup_benchmark_test.go @@ -0,0 +1,121 @@ +package nebula + +import ( + "bytes" + "fmt" + "net/netip" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + cert_test "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/require" +) + +func BenchmarkReloadConfigWithCAs(b *testing.B) { + prevProcs := runtime.GOMAXPROCS(1) + b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) }) + + for _, size := range 
[]int{100, 250, 500, 1000, 5000} { + b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) { + l := test.NewLogger() + dir := b.TempDir() + + ca, caKey, caBundle := buildCABundle(b, size) + caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle) + + configBody := fmt.Sprintf(`pki: + ca: %s + cert: %s + key: %s +`, caPath, certPath, keyPath) + + configPath := filepath.Join(dir, "config.yml") + require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600)) + + c := config.NewC(l) + require.NoError(b, c.Load(dir)) + + _, err := NewPKIFromConfig(test.NewLogger(), c) + require.NoError(b, err) + + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + c.ReloadConfig() + } + }) + } +} + +func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) { + b.Helper() + require.GreaterOrEqual(b, count, 1) + + before := time.Now().Add(-24 * time.Hour) + after := time.Now().Add(24 * time.Hour) + + ca, _, caKey, pem := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + before, + after, + nil, + nil, + nil, + ) + + buf := bytes.NewBuffer(pem) + buf.Write([]byte("\n# a comment!\n")) + + for i := 1; i < count; i++ { + _, _, _, extraPEM := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + time.Now(), + time.Now().Add(time.Hour), + nil, + nil, + nil, + ) + buf.Write([]byte("\n# a comment!\n")) + buf.Write(extraPEM) + } + + return ca, caKey, buf.Bytes() +} + +func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) { + b.Helper() + + networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")} + + _, _, keyPEM, certPEM := cert_test.NewTestCert( + cert.Version2, + cert.Curve_CURVE25519, + ca, + caKey, + "reload-benchmark", + time.Now(), + time.Now().Add(time.Hour), + networks, + nil, + nil, + ) + + caPath := filepath.Join(dir, "ca.pem") + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + 
require.NoError(b, os.WriteFile(caPath, caBundle, 0o600)) + require.NoError(b, os.WriteFile(certPath, certPEM, 0o600)) + require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600)) + + return caPath, certPath, keyPath +} diff --git a/punchy.go b/punchy.go index 2034405a..6ecf4f85 100644 --- a/punchy.go +++ b/punchy.go @@ -1,10 +1,10 @@ package nebula import ( + "log/slog" "sync/atomic" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) @@ -14,10 +14,10 @@ type Punchy struct { delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool - l *logrus.Logger + l *slog.Logger } -func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { +func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { p := &Punchy{l: l} p.reload(c, true) @@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { p.respond.Store(yes) if !initial { - p.l.Infof("punchy.respond changed to %v", p.GetRespond()) + p.l.Info("punchy.respond changed", "respond", p.GetRespond()) } } @@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) { if initial || c.HasChanged("punchy.delay") { p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) if !initial { - p.l.Infof("punchy.delay changed to %s", p.GetDelay()) + p.l.Info("punchy.delay changed", "delay", p.GetDelay()) } } if initial || c.HasChanged("punchy.target_all_remotes") { p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) if !initial { - p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") + p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) } } if initial || c.HasChanged("punchy.respond_delay") { p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) if !initial { - p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay()) + p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) } } } diff --git 
a/punchy_test.go b/punchy_test.go index 56dd1c25..cbf9b17b 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -1,6 +1,8 @@ package nebula import ( + "context" + "log/slog" "testing" "time" @@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.False(t, p.GetPunch()) assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) @@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) { // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } @@ -62,7 +64,7 @@ punchy: delay: 1m respond: false `)) - p := NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, delay, p.GetDelay()) assert.False(t, p.GetRespond()) @@ -76,3 +78,158 @@ punchy: assert.Equal(t, newDelay, p.GetDelay()) assert.True(t, p.GetRespond()) } + +// The tests below pin 
the shape of each log line Punchy produces so changes +// cannot silently break whatever operators are grepping for. The assertions +// are on the structured message + attrs (e.g. "punchy.respond changed" with +// a respond=true field) rather than a formatted string. +// +// Punchy.reload also emits a spurious "Changing punchy.punch with reload is +// not supported" warning whenever any key under punchy changes, because of +// the c.HasChanged("punchy") fallback kept for the deprecated top-level +// punchy form. The tests filter by message rather than asserting total +// entry counts so that warning is tolerated without being locked into +// the format. + +type capturedEntry struct { + Level slog.Level + Msg string + Attrs map[string]any +} + +// capturingHandler is a slog.Handler that records each Record it receives so +// tests can assert on the level, message, and attribute map of individual log +// lines without coupling to any specific text format. +type capturingHandler struct { + entries []capturedEntry +} + +func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } + +func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error { + e := capturedEntry{ + Level: r.Level, + Msg: r.Message, + Attrs: make(map[string]any), + } + r.Attrs(func(a slog.Attr) bool { + e.Attrs[a.Key] = a.Value.Resolve().Any() + return true + }) + h.entries = append(h.entries, e) + return nil +} + +func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h } + +func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) { + t.Helper() + hook := &capturingHandler{} + return slog.New(hook), hook +} + +func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry { + t.Helper() + for _, e := range entries { + if e.Msg == msg { + return e + } + } + t.Fatalf("no entry with message %q among %d entries", msg, len(entries)) + 
return capturedEntry{} +} + +func TestPunchy_LogFormat_InitialEnabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: true}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, "punchy enabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_InitialDisabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: false}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, "punchy disabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) + + entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") + assert.Equal(t, slog.LevelWarn, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) + + entry := findEntry(t, hook.entries, "punchy.respond changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) + NewPunchyFromConfig(l, c) + 
hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) + + entry := findEntry(t, hook.entries, "punchy.delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"delay": 10 * time.Second}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) + + entry := findEntry(t, hook.entries, "punchy.target_all_remotes changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"target_all_remotes": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`)) + + entry := findEntry(t, hook.entries, "punchy.respond_delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, entry.Attrs) +} diff --git a/relay_manager.go b/relay_manager.go index 5dd355ca..25e65871 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -5,22 +5,23 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net/netip" "sync/atomic" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) type relayManager struct { - l *logrus.Logger - hostmap *HostMap - amRelay atomic.Bool + l *slog.Logger + hostmap *HostMap + amRelay atomic.Bool + useRelays atomic.Bool } -func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) 
*relayManager { +func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager { rm := &relayManager{ l: l, hostmap: hostmap, @@ -29,15 +30,17 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c c.RegisterReloadCallback(func(c *config.C) { err := rm.reload(c, false) if err != nil { - l.WithError(err).Error("Failed to reload relay_manager") + rm.l.Error("Failed to reload relay_manager", "error", err) } }) return rm } func (rm *relayManager) reload(c *config.C, initial bool) error { - if initial || c.HasChanged("relay.am_relay") { - rm.setAmRelay(c.GetBool("relay.am_relay", false)) + if initial || c.HasChanged("relay.am_relay") || c.HasChanged("relay.use_relays") { + amRelay := c.GetBool("relay.am_relay", false) + rm.amRelay.Store(amRelay) + rm.useRelays.Store(c.GetBool("relay.use_relays", true) && !amRelay) } return nil } @@ -46,16 +49,165 @@ func (rm *relayManager) GetAmRelay() bool { return rm.amRelay.Load() } -func (rm *relayManager) setAmRelay(v bool) { - rm.amRelay.Store(v) +func (rm *relayManager) GetUseRelays() bool { + return rm.useRelays.Load() +} + +// StartRelays drives the relay-establishment side of an outbound handshake attempt. +// For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits +// one that may have been lost, or, once the relay is Established, forwards the in-progress +// stage 0 handshake packet for vpnIp through it. 
+func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *HostInfo, stage0 []byte) { + if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 { + return + } + + hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) + // Send a RelayRequest to all known Relay IP's + for _, relay := range hostinfo.remotes.relays { + // Don't relay through the host I'm trying to connect to + if relay == vpnIp { + continue + } + + // Don't relay to myself + if f.myVpnAddrsTable.Contains(relay) { + continue + } + + relayHostInfo := rm.hostmap.QueryVpnAddr(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { + hostinfo.logger(rm.l).Info("Establish tunnel to relay target", "relay", relay.String()) + f.Handshake(relay) + continue + } + // Check the relay HostInfo to see if we already established a relay through + existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) + if !ok { + // No relays exist or requested yet. 
+ if relayHostInfo.remote.IsValid() { + idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested) + if err != nil { + hostinfo.logger(rm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) + } + + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: idx, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !f.myVpnAddrs[0].Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + continue + } + + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + } else { + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.Info("send CreateRelayRequest", + "relayFrom", f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", idx, + "relay", relay, + ) + } + } + continue + } + + switch existingRelay.State { + case Established: + hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String()) + f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false) + case Disestablished: + // Mark this relay as 'requested' + relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) + fallthrough + case Requested: + 
hostinfo.logger(rm.l).Info("Re-send CreateRelay request", "relay", relay.String()) + // Re-send the CreateRelay request, in case the previous one was lost. + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: existingRelay.LocalIndex, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !f.myVpnAddrs[0].Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + } else { + // This must send over the hostinfo, not over hm.Hosts[ip] + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.Info("send CreateRelayRequest", + "relayFrom", f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", existingRelay.LocalIndex, + "relay", relay, + ) + } + case PeerRequested: + // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. + fallthrough + default: + hostinfo.logger(rm.l).Error("Relay unexpected state", + "vpnIp", vpnIp, + "state", existingRelay.State, + "relay", relay, + ) + + } + } } // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. 
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() - for i := 0; i < 32; i++ { + for range 32 { index, err := generateIndex(l) if err != nil { return 0, err @@ -92,24 +244,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - fields := logrus.Fields{ - "relay": relayHostInfo.vpnAddrs[0], - "initiatorRelayIndex": m.InitiatorRelayIndex, - } - + var relayFrom, relayTo any if m.RelayFromAddr == nil { - fields["relayFrom"] = m.OldRelayFromAddr + relayFrom = m.OldRelayFromAddr } else { - fields["relayFrom"] = m.RelayFromAddr + relayFrom = m.RelayFromAddr } - if m.RelayToAddr == nil { - fields["relayTo"] = m.OldRelayToAddr + relayTo = m.OldRelayToAddr } else { - fields["relayTo"] = m.RelayToAddr + relayTo = m.RelayToAddr } - rm.l.WithFields(fields).Info("relayManager failed to update relay") + rm.l.Info("relayManager failed to update relay", + "relay", relayHostInfo.vpnAddrs[0], + "initiatorRelayIndex", m.InitiatorRelayIndex, + "relayFrom", relayFrom, + "relayTo", relayTo, + ) return nil, fmt.Errorf("unknown relay") } @@ -120,7 +272,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { msg := &NebulaControl{} err := msg.Unmarshal(d) if err != nil { - h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + h.logger(f.l).Error("Failed to unmarshal control message", "error", err) return } @@ -147,20 +299,20 @@ func 
(rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { } func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { - rm.l.WithFields(logrus.Fields{ - "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), - "relayTo": protoAddrToNetAddr(m.RelayToAddr), - "initiatorRelayIndex": m.InitiatorRelayIndex, - "responderRelayIndex": m.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). - Info("handleCreateRelayResponse") + rm.l.Info("handleCreateRelayResponse", + "relayFrom", protoAddrToNetAddr(m.RelayFromAddr), + "relayTo", protoAddrToNetAddr(m.RelayToAddr), + "initiatorRelayIndex", m.InitiatorRelayIndex, + "responderRelayIndex", m.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) target := m.RelayToAddr targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { - rm.l.WithError(err).Error("Failed to update relay for relayTo") + rm.l.Error("Failed to update relay for relayTo", "error", err) return } // Do I need to complete the relays now? @@ -170,12 +322,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f // I'm the middle man. Let the initiator know that the I've established the relay they requested. 
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") + rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr) return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") + rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0]) return } switch peerRelay.State { @@ -193,12 +345,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f if v == cert.Version1 { peer := peerHostInfo.vpnAddrs[0] if !peer.Is4() { - rm.l.WithField("relayFrom", peer). - WithField("relayTo", target). - WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). - WithField("responderRelayIndex", resp.ResponderRelayIndex). - WithField("vpnAddrs", peerHostInfo.vpnAddrs). - Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address", + "relayFrom", peer, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) return } @@ -213,17 +366,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - rm.l.WithError(err). 
- Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromAddr, - "relayTo": resp.RelayToAddr, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": peerHostInfo.vpnAddrs}). - Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", resp.RelayFromAddr, + "relayTo", resp.RelayToAddr, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) } } } @@ -232,17 +384,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f from := protoAddrToNetAddr(m.RelayFromAddr) target := protoAddrToNetAddr(m.RelayToAddr) - logMsg := rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnAddrs": h.vpnAddrs}) + logMsg := rm.l.With( + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", m.InitiatorRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. if f.myVpnAddrsTable.Contains(from) { - logMsg.WithField("myIP", from).Error("Discarding relay request from myself") + logMsg.Error("Discarding relay request from myself", "myIP", from) return } @@ -261,37 +414,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. 
// This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } case Disestablished: if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } // Mark the relay as 'Established' because it's safe to use again h.relayState.UpdateRelayForByIpState(from, Established) case PeerRequested: // I should never be in this state, because I am terminal, not forwarding. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex, - "state": existingRelay.State}).Error("Unexpected Relay State found") + logMsg.Error("Unexpected Relay State found", + "existingRemoteIndex", existingRelay.RemoteIndex, + "state", existingRelay.State) } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) if err != nil { - logMsg.WithError(err).Error("Failed to add relay") + logMsg.Error("Failed to add relay", "error", err) return } } relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.WithField("from", from).Error("Relay State not found") + logMsg.Error("Relay State not found", "from", from) return } @@ -313,17 +466,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - logMsg. 
- WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). - Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) } return } else { @@ -363,12 +515,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if v == cert.Version1 { if !h.vpnAddrs[0].Is4() { - rm.l.WithField("relayFrom", h.vpnAddrs[0]). - WithField("relayTo", target). - WithField("initiatorRelayIndex", req.InitiatorRelayIndex). - WithField("responderRelayIndex", req.ResponderRelayIndex). - WithField("vpnAddr", target). - Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) return } @@ -383,17 +536,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := req.Marshal() if err != nil { - logMsg. 
- WithError(err).Error("relayManager Failed to marshal Control message to create relay") + logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": h.vpnAddrs[0], - "relayTo": target, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddr": target}). - Info("send CreateRelayRequest") + rm.l.Info("send CreateRelayRequest", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) } // Also track the half-created Relay state just received @@ -401,8 +553,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if !ok { _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { - logMsg. 
- WithError(err).Error("relayManager Failed to allocate a local index for relay") + logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err) return } } diff --git a/remote_list.go b/remote_list.go index 1304fd51..7b95de87 100644 --- a/remote_list.go +++ b/remote_list.go @@ -2,6 +2,7 @@ package nebula import ( "context" + "log/slog" "net" "net/netip" "slices" @@ -10,8 +11,6 @@ import ( "sync" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // forEachFunc is used to benefit folks that want to do work inside the lock @@ -66,11 +65,11 @@ type hostnamesResults struct { network string lookupTimeout time.Duration cancelFn func() - l *logrus.Logger + l *slog.Logger ips atomic.Pointer[map[netip.AddrPort]struct{}] } -func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { +func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { r := &hostnamesResults{ hostnames: make([]hostnamePort, len(hostPorts)), network: network, @@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) timeoutCancel() if err != nil { - l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") + l.Error("DNS resolution failed for static_map host", + "hostname", hostPort.name, + "network", r.network, + "error", err, + ) continue } for _, a := range addrs { @@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, } } if different { - l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + l.Info("DNS results 
changed for host list", + "origSet", origSet, + "newSet", netipAddrs, + ) r.ips.Store(&netipAddrs) onUpdate() } @@ -404,12 +410,7 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { - for _, v := range r.badRemotes { - if v == remote { - return true - } - } - return false + return slices.Contains(r.badRemotes, remote) } // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the diff --git a/service/service.go b/service/service.go index fc8ac97a..899e851d 100644 --- a/service/service.go +++ b/service/service.go @@ -44,7 +44,10 @@ type Service struct { } func New(control *nebula.Control) (*Service, error) { - control.Start() + wait, err := control.Start() + if err != nil { + return nil, err + } ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) @@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) { } }) + // Add the nebula wait function to the group so a fatal reader error + // propagates out through errgroup.Wait(). 
+ eg.Go(func() error { + return wait() + }) + return &s, nil } diff --git a/service/service_test.go b/service/service_test.go index c6b87423..4bcc8437 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,11 +10,11 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "go.yaml.in/yaml/v3" "golang.org/x/sync/errgroup" @@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n panic(err) } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/ssh.go b/ssh.go index 9a26c290..3863b5ec 100644 --- a/ssh.go +++ b/ssh.go @@ -6,19 +6,21 @@ import ( "errors" "flag" "fmt" + "log/slog" + "maps" "net" "net/netip" "os" - "reflect" + "path/filepath" "runtime" "runtime/pprof" "sort" "strconv" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/sshd" ) @@ -55,12 +57,12 @@ type sshDeviceInfoFlags struct { Pretty bool } -func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { +func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) { c.RegisterReloadCallback(func(c *config.C) { if c.GetBool("sshd.enabled", false) { sshRun, err := configSSH(l, ssh, c) if err != nil { - l.WithError(err).Error("Failed to reconfigure the sshd") + l.Error("Failed to reconfigure the sshd", "error", err) ssh.Stop() } if sshRun != nil { @@ -76,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { // updates the passed-in SSHServer. 
On success, it returns a function // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. -func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { +func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { listen := c.GetString("sshd.listen", "") if listen == "" { return nil, fmt.Errorf("sshd.listen must be provided") @@ -118,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, caAuthorizedKey := range rawCAs { err := ssh.AddTrustedCA(caAuthorizedKey) if err != nil { - l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring") + l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey) continue } } @@ -129,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, rk := range keys { kDef, ok := rk.(map[string]any) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") + l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk) continue } user, ok := kDef["user"].(string) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field") + l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk) continue } @@ -144,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro case string: err := ssh.AddAuthorizedKey(user, v) if err != nil { - l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", rk, + "sshKey", v, + ) continue } @@ -152,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, subK := range v { sk, ok := subK.(string) if !ok { - l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key") + l.Warn("Did not 
understand ssh key", + "sshKeyConfig", rk, + "sshKey", subK, + ) continue } err := ssh.AddAuthorizedKey(user, sk) if err != nil { - l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", sk, + ) continue } } default: - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood") + l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk) } } } else { @@ -176,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro ssh.Stop() runner = func() { if err := ssh.Run(listen); err != nil { - l.WithField("err", err).Warn("Failed to run the SSH server") + l.Warn("Failed to run the SSH server", "error", err) } } } else { @@ -186,7 +198,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return runner, nil } -func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { +func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { + // sandboxDir defaults to a dir in temp. The intention is that end user will + // create this dir as needed. Overriding this config value to "" allows + // writing to anywhere in the system. 
+ defaultDir := filepath.Join(os.TempDir(), "nebula-debug") + sandboxDir := c.GetString("sshd.sandbox_dir", defaultDir) + ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", @@ -245,7 +263,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "start-cpu-profile", ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`", - Callback: sshStartCpuProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshStartCpuProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -260,7 +280,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-heap-profile", ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`", - Callback: sshGetHeapProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshGetHeapProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -272,7 +294,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-mutex-profile", ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`", - Callback: sshGetMutexProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshGetMutexProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -505,13 +529,43 @@ func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) er return nil } -func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error { +// sshSanitizeFilePath validates that the given file path is within the sandbox directory. +// If sandboxDir is empty, the path is returned as-is for backwards compatibility. 
+func sshSanitizeFilePath(sandboxDir, filePath string) (string, error) { + if sandboxDir == "" { + return filePath, nil + } + + // Clean and resolve the path relative to the sandbox directory + if !filepath.IsAbs(filePath) { + filePath = filepath.Join(sandboxDir, filePath) + } + cleaned := filepath.Clean(filePath) + + // Ensure the resolved path is within the sandbox directory + cleanedSandbox := filepath.Clean(sandboxDir) + if cleaned == cleanedSandbox { + return "", fmt.Errorf("path %q resolves to the sandbox directory itself %q", filePath, sandboxDir) + } + if !strings.HasPrefix(cleaned, cleanedSandbox+string(filepath.Separator)) { + return "", fmt.Errorf("path %q is outside the sandbox directory %q", filePath, sandboxDir) + } + + return cleaned, nil +} + +func sshStartCpuProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { err := w.WriteLine("No path to write profile provided") return err } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) return err @@ -675,12 +729,17 @@ func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) e return w.WriteLine("Changed") } -func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error { +func sshGetHeapProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) return err @@ -711,12 +770,17 @@ func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error { return 
w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate)) } -func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { +func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) } @@ -735,36 +799,45 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetLevel() slog.Level + SetLevel(slog.Level) + }) + if !ok { + return w.WriteLine("Log level is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } - level, err := logrus.ParseLevel(a[0]) + level, err := logging.ParseLevel(strings.ToLower(a[0])) if err != nil { - return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels)) + return w.WriteLine(fmt.Sprintf("Unknown log level %s. 
Possible log levels: trace, debug, info, warn, error", a)) } - l.SetLevel(level) - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + ctrl.SetLevel(level) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } -func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetFormat() string + SetFormat(string) error + }) + if !ok { + return w.WriteLine("Log format is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } - logFormat := strings.ToLower(a[0]) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{} - case "json": - l.Formatter = &logrus.JSONFormatter{} - default: - return fmt.Errorf("unknown log format `%s`. 
possible formats: %s", logFormat, []string{"text", "json"}) + if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil { + return err } - - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { @@ -831,9 +904,7 @@ func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) er relays := map[uint32]*HostInfo{} ifce.hostMap.Lock() - for k, v := range ifce.hostMap.Relays { - relays[k] = v - } + maps.Copy(relays, ifce.hostMap.Relays) ifce.hostMap.Unlock() type RelayFor struct { diff --git a/sshd/server.go b/sshd/server.go index a8b60ba7..38886e53 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -2,19 +2,19 @@ package sshd import ( "bytes" + "context" "errors" "fmt" + "log/slog" "net" - "sync" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) type SSHServer struct { config *ssh.ServerConfig - l *logrus.Entry + l *slog.Logger certChecker *ssh.CertChecker @@ -27,20 +27,21 @@ type SSHServer struct { commands *radix.Tree listener net.Listener - // Locks the conns/counter to avoid concurrent map access - connsLock sync.Mutex - conns map[int]*session - counter int + // Call the cancel() function to stop all active sessions + ctx context.Context + cancel func() } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen -func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { +func NewSSHServer(l *slog.Logger) (*SSHServer, error) { + ctx, cancel := context.WithCancel(context.Background()) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), - conns: make(map[int]*session), + ctx: ctx, + cancel: cancel, } cc := ssh.CertChecker{ @@ -120,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error { } s.trustedCAs = append(s.trustedCAs, pk) - 
s.l.WithField("sshKey", pubKey).Info("Trusted CA key") + s.l.Info("Trusted CA key", "sshKey", pubKey) return nil } @@ -138,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { } tk[string(pk.Marshal())] = true - s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") + s.l.Info("Authorized ssh key", + "sshKey", pubKey, + "sshUser", user, + ) return nil } @@ -155,7 +159,7 @@ func (s *SSHServer) Run(addr string) error { return err } - s.l.WithField("sshListener", addr).Info("SSH server is listening") + s.l.Info("SSH server is listening", "sshListener", addr) // Run loops until there is an error s.run() @@ -171,48 +175,54 @@ func (s *SSHServer) run() { c, err := s.listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { - s.l.WithError(err).Warn("Error in listener, shutting down") + s.l.Warn("Error in listener, shutting down", "error", err) } return } - - conn, chans, reqs, err := ssh.NewServerConn(c, s.config) - fp := "" - if conn != nil { - fp = conn.Permissions.Extensions["fp"] - } - - if err != nil { - l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + go func(c net.Conn) { + // NewServerConn may block while waiting for the client to complete the handshake. + // Ensure that a bad client doesn't hurt us by checking for the parent context + // cancellation before calling NewServerConn, and forcing the socket to close when + // the context is cancelled. 
+ sessionContext, sessionCancel := context.WithCancel(s.ctx) + go func() { + <-sessionContext.Done() + c.Close() + }() + conn, chans, reqs, err := ssh.NewServerConn(c, s.config) + fp := "" if conn != nil { - l = l.WithField("sshUser", conn.User()) - conn.Close() + fp = conn.Permissions.Extensions["fp"] } - if fp != "" { - l = l.WithField("sshFingerprint", fp) + + if err != nil { + l := s.l.With( + "error", err, + "remoteAddress", c.RemoteAddr(), + ) + if conn != nil { + l = l.With("sshUser", conn.User()) + conn.Close() + } + if fp != "" { + l = l.With("sshFingerprint", fp) + } + l.Warn("failed to handshake") + sessionCancel() + return } - l.Warn("failed to handshake") - continue - } - l := s.l.WithField("sshUser", conn.User()) - l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") + l := s.l.With("sshUser", conn.User()) + l.Info("ssh user logged in", + "remoteAddress", c.RemoteAddr(), + "sshFingerprint", fp, + ) - session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) - s.connsLock.Lock() - s.counter++ - counter := s.counter - s.conns[counter] = session - s.connsLock.Unlock() + NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session")) - go ssh.DiscardRequests(reqs) - go func() { - <-session.exitChan - s.l.WithField("id", counter).Debug("closing conn") - s.connsLock.Lock() - delete(s.conns, counter) - s.connsLock.Unlock() - }() + go ssh.DiscardRequests(reqs) + + }(c) } } @@ -220,15 +230,11 @@ func (s *SSHServer) Stop() { // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { - s.l.WithError(err).Warn("Failed to close the sshd listener") + s.l.Warn("Failed to close the sshd listener", "error", err) } } } func (s *SSHServer) closeSessions() { - s.connsLock.Lock() - for _, c := range s.conns { - c.Close() - } - s.connsLock.Unlock() + s.cancel() } diff 
--git a/sshd/session.go b/sshd/session.go index 87cc216f..1c8e1a9b 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -2,30 +2,30 @@ package sshd import ( "fmt" + "log/slog" "sort" "strings" "github.com/anmitsu/go-shlex" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "golang.org/x/term" ) type session struct { - l *logrus.Entry + l *slog.Logger c *ssh.ServerConn term *term.Terminal commands *radix.Tree - exitChan chan bool + cancel func() } -func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { +func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session { s := &session{ commands: radix.NewFromMap(commands.ToMap()), l: l, c: conn, - exitChan: make(chan bool), + cancel: cancel, } s.commands.Insert("logout", &Command{ @@ -42,16 +42,17 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New } func (s *session) handleChannels(chans <-chan ssh.NewChannel) { + defer s.Close() for newChannel := range chans { if newChannel.ChannelType() != "session" { - s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") + s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType()) newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { - s.l.WithError(err).Warn("could not accept channel") + s.l.Warn("could not accept channel", "error", err) continue } @@ -94,13 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { return default: - s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request") + s.l.Debug("Rejected unknown request", "sshRequest", req.Type) err = req.Reply(false, nil) } if err != nil { - s.l.WithError(err).Info("Error handling ssh session requests") - s.Close() + s.l.Info("Error handling ssh 
session requests", "error", err) return } } @@ -123,12 +123,11 @@ func (s *session) createTerm(channel ssh.Channel) *term.Terminal { return "", 0, false } - go s.handleInput(channel) + go s.handleInput() return term } -func (s *session) handleInput(channel ssh.Channel) { - defer s.Close() +func (s *session) handleInput() { w := &stringWriter{w: s.term} for { line, err := s.term.ReadLine() @@ -170,10 +169,9 @@ func (s *session) dispatchCommand(line string, w StringWriter) { } _ = execCommand(c, args[1:], w) - return } func (s *session) Close() { s.c.Close() - s.exitChan <- true + s.cancel() } diff --git a/stats.go b/stats.go index c88c45cc..97ce7cf5 100644 --- a/stats.go +++ b/stats.go @@ -1,13 +1,16 @@ package nebula import ( + "context" "errors" "fmt" - "log" + "log/slog" "net" "net/http" "runtime" "strconv" + "sync" + "sync/atomic" "time" graphite "github.com/cyberdelia/go-metrics-graphite" @@ -15,113 +18,350 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) -// startStats initializes stats from config. On success, if any further work -// is needed to serve stats, it returns a func to handle that work. If no -// work is needed, it'll return nil. On failure, it returns nil, error. -func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { - mType := c.GetString("stats.type", "") - if mType == "" || mType == "none" { - return nil, nil - } +// statsServer owns nebula's stats subsystem: the periodic metric capture +// goroutine and (for prometheus) an HTTP listener. It mirrors the lifecycle +// shape of dnsServer: constructor wires the reload callback, reload records +// config, Start builds and runs the runtime, Stop tears it down. 
+type statsServer struct { + l *slog.Logger + ctx context.Context + buildVersion string + configTest bool - interval := c.GetDuration("stats.interval", 0) - if interval == 0 { - return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) - } + // enabled mirrors "stats configured to a real backend". Start consults + // it so callers don't need to know the gating rules. + enabled atomic.Bool - var startFn func() - switch mType { - case "graphite": - err := startGraphiteStats(l, interval, c, configTest) - if err != nil { - return nil, err - } - case "prometheus": - var err error - startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest) - if err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("stats.type was not understood: %s", mType) - } - - metrics.RegisterDebugGCStats(metrics.DefaultRegistry) - metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) - - go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval) - go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval) - - return startFn, nil + runMu sync.Mutex + runCfg *statsConfig + run *statsRuntime // non-nil while a runtime is live } -func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error { - proto := c.GetString("stats.protocol", "tcp") - host := c.GetString("stats.host", "") - if host == "" { - return errors.New("stats.host can not be empty") +// statsRuntime is the live state owned by a single Start invocation. Start +// stashes a pointer under runMu; Stop and Start's own exit path use pointer +// equality to tell "my runtime" apart from one that replaced it after a +// reload. +type statsRuntime struct { + cancel context.CancelFunc + listener *http.Server // nil for graphite +} + +// statsConfig is the snapshot of stats-related config that drives the runtime. +// It is comparable with == so reload can detect "no change" cheaply. 
+type statsConfig struct { + typ string + interval time.Duration + graphite graphiteConfig + prom promConfig +} + +type graphiteConfig struct { + protocol string + host string + // resolvedAddr is the string form of host resolved at config-load time. + // Including it in the struct means a SIGHUP picks up DNS changes even + // when stats.host hasn't been edited. + resolvedAddr string + prefix string +} + +type promConfig struct { + listen string + path string + namespace string + subsystem string +} + +// newStatsServerFromConfig builds a statsServer, applies the initial config, +// and registers a reload callback. The reload callback is registered before +// the initial config is applied so a SIGHUP can later enable, fix, or disable +// stats even if the initial application failed. +// +// Start is safe to call unconditionally: it no-ops when stats are disabled. +// The returned pointer is always non-nil, even on error. +func newStatsServerFromConfig(ctx context.Context, l *slog.Logger, c *config.C, buildVersion string, configTest bool) (*statsServer, error) { + s := &statsServer{ + l: l, + ctx: ctx, + buildVersion: buildVersion, + configTest: configTest, } - prefix := c.GetString("stats.prefix", "nebula") - addr, err := net.ResolveTCPAddr(proto, host) + c.RegisterReloadCallback(func(c *config.C) { + if err := s.reload(c, false); err != nil { + s.l.Error("Failed to reload stats from config", "error", err) + } + }) + + if err := s.reload(c, true); err != nil { + return s, err + } + return s, nil +} + +// reload records the latest config. On the initial call it only records it; +// Control.Start is what launches the first runtime via statsStart. 
On later +// calls it reconciles the running runtime with the new config: +// +// - newly enabled -> spawn Start +// - newly disabled -> Stop the runtime +// - config changed (still enabled) -> Stop the old, Start the new +// - no change -> no-op +func (s *statsServer) reload(c *config.C, initial bool) error { + newCfg, err := loadStatsConfig(c) if err != nil { - return fmt.Errorf("error while setting up graphite sink: %s", err) + return err + } + enabled := newCfg.typ != "" && newCfg.typ != "none" + + s.runMu.Lock() + sameCfg := s.runCfg != nil && *s.runCfg == newCfg + s.runCfg = &newCfg + running := s.run != nil + s.runMu.Unlock() + + s.enabled.Store(enabled) + + if initial || sameCfg { + return nil } - if !configTest { - l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) - go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) + if running { + s.Stop() + } + if enabled && !s.configTest { + go s.Start() } return nil } -func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { - namespace := c.GetString("stats.namespace", "") - subsystem := c.GetString("stats.subsystem", "") - - listen := c.GetString("stats.listen", "") - if listen == "" { - return nil, fmt.Errorf("stats.listen should not be empty") +// Start builds the runtime from the latest config, spawns the capture loop, +// and blocks until Stop is called or ctx fires. For prometheus it also serves +// the HTTP listener. For graphite it blocks on the capture loop's context. +// Safe to call when stats are disabled or already running (both no-op). 
+func (s *statsServer) Start() { + if !s.enabled.Load() || s.configTest { + return } - path := c.GetString("stats.path", "") - if path == "" { - return nil, fmt.Errorf("stats.path should not be empty") + s.runMu.Lock() + if s.ctx.Err() != nil || s.run != nil || s.runCfg == nil { + s.runMu.Unlock() + return + } + cfg := *s.runCfg + captureFns, listener := s.buildRuntime(cfg) + runCtx, cancel := context.WithCancel(s.ctx) + rt := &statsRuntime{cancel: cancel, listener: listener} + s.run = rt + s.runMu.Unlock() + + go captureStatsLoop(runCtx, cfg.interval, captureFns) + + cleanExit := true + if listener == nil { + // Graphite: no HTTP listener to serve; block until teardown. + <-runCtx.Done() + } else { + cleanExit = s.serveListener(listener) } - pr := prometheus.NewRegistry() - pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) - if !configTest { - go pClient.UpdatePrometheusMetrics() - } - - // Export our version information as labels on a static gauge - g := prometheus.NewGauge(prometheus.GaugeOpts{ - Namespace: namespace, - Subsystem: subsystem, - Name: "info", - Help: "Version information for the Nebula binary", - ConstLabels: prometheus.Labels{ - "version": buildVersion, - "goversion": runtime.Version(), - "boringcrypto": strconv.FormatBool(boringEnabled()), - }, - }) - pr.MustRegister(g) - g.Set(1) - - var startFn func() - if !configTest { - startFn = func() { - l.Infof("Prometheus stats listening on %s at %s", listen, path) - http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) - log.Fatal(http.ListenAndServe(listen, nil)) + // Clear our runtime only if nothing has replaced it. Stop races through + // here too but leaves s.run == nil, so the pointer check skips. + s.runMu.Lock() + if s.run == rt { + rt.cancel() + s.run = nil + // A listener that exited with an error (e.g., bind conflict) leaves + // runCfg cached as if it were applied. 
Drop it so a SIGHUP with the + // same config re-triggers Start once the user fixes the underlying + // problem. + if !cleanExit { + s.runCfg = nil } } - - return startFn, nil + s.runMu.Unlock() +} + +// serveListener runs ListenAndServe and ensures ctx cancellation unblocks it. +// Returns true if the listener exited cleanly (Stop, ctx cancellation, or any +// other http.ErrServerClosed path), false on an unexpected error. +func (s *statsServer) serveListener(listener *http.Server) bool { + // Per-invocation watcher: ctx cancellation triggers a listener shutdown + // which in turn unblocks ListenAndServe. Closing `done` on exit keeps + // the watcher from outliving this call. + done := make(chan struct{}) + go func() { + select { + case <-s.ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := listener.Shutdown(shutdownCtx); err != nil { + s.l.Warn("Failed to shut down prometheus stats listener", "error", err) + } + case <-done: + } + }() + defer close(done) + + s.l.Info("Starting prometheus stats listener", "addr", listener.Addr) + err := listener.ListenAndServe() + if err == nil || errors.Is(err, http.ErrServerClosed) { + return true + } + s.l.Error("Prometheus stats listener exited", "error", err) + return false +} + +// Stop tears down the active runtime, if any. Idempotent. +func (s *statsServer) Stop() { + s.runMu.Lock() + rt := s.run + s.run = nil + s.runMu.Unlock() + if rt == nil { + return + } + rt.cancel() + if rt.listener != nil { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := rt.listener.Shutdown(shutdownCtx); err != nil { + s.l.Warn("Failed to shut down prometheus stats listener", "error", err) + } + cancel() + } +} + +// buildRuntime produces the capture functions and, for prometheus, an un-served +// http.Server from cfg. cfg has already been validated by loadStatsConfig. 
+func (s *statsServer) buildRuntime(cfg statsConfig) ([]func(), *http.Server) { + // rcrowley/go-metrics guards these registrations with a private sync.Once, + // so subsequent reloads are no-ops. + metrics.RegisterDebugGCStats(metrics.DefaultRegistry) + metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) + + captureFns := []func(){ + func() { metrics.CaptureDebugGCStatsOnce(metrics.DefaultRegistry) }, + func() { metrics.CaptureRuntimeMemStatsOnce(metrics.DefaultRegistry) }, + } + + switch cfg.typ { + case "graphite": + // loadStatsConfig already resolved and validated the address; re-parse + // the resolved form (no DNS lookup) to get a *net.TCPAddr. + addr, _ := net.ResolveTCPAddr(cfg.graphite.protocol, cfg.graphite.resolvedAddr) + gcfg := graphite.Config{ + Addr: addr, + Registry: metrics.DefaultRegistry, + FlushInterval: cfg.interval, + DurationUnit: time.Nanosecond, + Prefix: cfg.graphite.prefix, + Percentiles: []float64{0.5, 0.75, 0.95, 0.99, 0.999}, + } + captureFns = append(captureFns, func() { + if err := graphite.Once(gcfg); err != nil { + s.l.Error("Graphite export failed", "error", err) + } + }) + s.l.Info("Starting graphite stats", + "interval", cfg.interval, + "prefix", cfg.graphite.prefix, + "addr", addr, + ) + return captureFns, nil + + case "prometheus": + pr := prometheus.NewRegistry() + pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, cfg.prom.namespace, cfg.prom.subsystem, pr, cfg.interval) + captureFns = append(captureFns, func() { + if err := pClient.UpdatePrometheusMetricsOnce(); err != nil { + s.l.Error("Prometheus metrics update failed", "error", err) + } + }) + + g := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: cfg.prom.namespace, + Subsystem: cfg.prom.subsystem, + Name: "info", + Help: "Version information for the Nebula binary", + ConstLabels: prometheus.Labels{ + "version": s.buildVersion, + "goversion": runtime.Version(), + "boringcrypto": strconv.FormatBool(boringEnabled()), + }, + }) + pr.MustRegister(g) + 
g.Set(1) + + // promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger, + // so bridge our slog.Logger back to a *log.Logger that emits at Error. + errLog := slog.NewLogLogger(s.l.Handler(), slog.LevelError) + mux := http.NewServeMux() + mux.Handle(cfg.prom.path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog})) + return captureFns, &http.Server{Addr: cfg.prom.listen, Handler: mux} + } + return captureFns, nil +} + +// captureStatsLoop runs each fn on every tick of d until ctx is cancelled. +func captureStatsLoop(ctx context.Context, d time.Duration, fns []func()) { + t := time.NewTicker(d) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + for _, fn := range fns { + fn() + } + } + } +} + +func loadStatsConfig(c *config.C) (statsConfig, error) { + cfg := statsConfig{ + typ: c.GetString("stats.type", ""), + } + if cfg.typ == "" || cfg.typ == "none" { + return cfg, nil + } + + cfg.interval = c.GetDuration("stats.interval", 0) + if cfg.interval == 0 { + return cfg, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) + } + + switch cfg.typ { + case "graphite": + cfg.graphite.protocol = c.GetString("stats.protocol", "tcp") + cfg.graphite.host = c.GetString("stats.host", "") + if cfg.graphite.host == "" { + return cfg, errors.New("stats.host can not be empty") + } + addr, err := net.ResolveTCPAddr(cfg.graphite.protocol, cfg.graphite.host) + if err != nil { + return cfg, fmt.Errorf("error while setting up graphite sink: %s", err) + } + cfg.graphite.resolvedAddr = addr.String() + cfg.graphite.prefix = c.GetString("stats.prefix", "nebula") + case "prometheus": + cfg.prom.listen = c.GetString("stats.listen", "") + if cfg.prom.listen == "" { + return cfg, errors.New("stats.listen should not be empty") + } + cfg.prom.path = c.GetString("stats.path", "") + if cfg.prom.path == "" { + return cfg, errors.New("stats.path should not be empty") + } + cfg.prom.namespace = 
c.GetString("stats.namespace", "") + cfg.prom.subsystem = c.GetString("stats.subsystem", "") + default: + return cfg, fmt.Errorf("stats.type was not understood: %s", cfg.typ) + } + + return cfg, nil } diff --git a/stats_test.go b/stats_test.go new file mode 100644 index 00000000..20b17c0e --- /dev/null +++ b/stats_test.go @@ -0,0 +1,410 @@ +package nebula + +import ( + "context" + "io" + "log/slog" + "net" + "strconv" + "testing" + "time" + + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestStatsServer(t *testing.T) (*statsServer, *config.C) { + t.Helper() + l := slog.New(slog.DiscardHandler) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return &statsServer{ + l: l, + ctx: ctx, + }, config.NewC(l) +} + +func setStatsConfig(c *config.C, m map[string]any) { + c.Settings["stats"] = m +} + +func currentRuntime(s *statsServer) *statsRuntime { + s.runMu.Lock() + defer s.runMu.Unlock() + return s.run +} + +func TestStatsServer_reload_initial_disabled(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{"type": "none"}) + + require.NoError(t, s.reload(c, true)) + assert.False(t, s.enabled.Load()) + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_initial_invalidInterval(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "host": "127.0.0.1:0", + "prefix": "test", + }) + + err := s.reload(c, true) + require.Error(t, err) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_initial_unknownType(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "carbon", + "interval": "1s", + }) + + err := s.reload(c, true) + require.Error(t, err) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_unchanged_noOp(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, 
map[string]any{"type": "none"}) + + require.NoError(t, s.reload(c, true)) + require.NoError(t, s.reload(c, false)) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_initial_graphite(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "interval": "1s", + "protocol": "tcp", + "host": "127.0.0.1:2003", + "prefix": "test", + }) + + require.NoError(t, s.reload(c, true)) + assert.True(t, s.enabled.Load()) + // reload only records config; Start builds the runtime. + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_initial_prometheus(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + + require.NoError(t, s.reload(c, true)) + assert.True(t, s.enabled.Load()) + // reload only records config; Start builds the runtime. + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_Start_graphite_blocksUntilStop(t *testing.T) { + sink := newGraphiteSink(t) + defer sink.Close() + + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "interval": "1s", + "protocol": "tcp", + "host": sink.Addr(), + "prefix": "test", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + + // Wait for Start to publish runtime state. 
+ waitFor(t, func() bool { return currentRuntime(s) != nil }) + rt := currentRuntime(s) + require.NotNil(t, rt) + assert.Nil(t, rt.listener, "graphite has no listener") + + s.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("graphite Start did not return after Stop") + } + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_StartStop_lifecycle(t *testing.T) { + port := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + + waitForListening(t, "127.0.0.1:"+port) + rt := currentRuntime(s) + require.NotNil(t, rt) + require.NotNil(t, rt.listener) + + s.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_disable_stopsRunningRuntime(t *testing.T) { + port := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + waitForListening(t, "127.0.0.1:"+port) + + setStatsConfig(c, map[string]any{"type": "none"}) + require.NoError(t, s.reload(c, false)) + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after reload disabled stats") + } + assert.False(t, s.enabled.Load()) + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_changeListener_restartsListener(t *testing.T) { + port1 := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port1, + "path": 
"/metrics", + }) + require.NoError(t, s.reload(c, true)) + + firstDone := make(chan struct{}) + go func() { + s.Start() + close(firstDone) + }() + waitForListening(t, "127.0.0.1:"+port1) + first := currentRuntime(s) + require.NotNil(t, first) + + port2 := freeTCPPort(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port2, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, false)) + + select { + case <-firstDone: + case <-time.After(5 * time.Second): + t.Fatal("old Start did not return after reload") + } + + waitForListening(t, "127.0.0.1:"+port2) + second := currentRuntime(s) + require.NotNil(t, second) + assert.NotSame(t, first, second, "expected a new runtime after listen address change") + + s.Stop() +} + +func TestStatsServer_Stop_beforeStart_doesNotBlock(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + stopped := make(chan struct{}) + go func() { + s.Stop() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("Stop hung with no runtime started") + } +} + +func TestStatsServer_configTest_validatesWithoutSpawning(t *testing.T) { + s, c := newTestStatsServer(t) + s.configTest = true + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + + require.NoError(t, s.reload(c, true)) + s.Start() + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_ctxCancel_unblocksStart(t *testing.T) { + // Ensures ctx cancellation alone (no explicit Stop) tears down both + // graphite and prom Start invocations. 
+ port := freeTCPPort(t) + l := slog.New(slog.DiscardHandler) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := &statsServer{l: l, ctx: ctx} + c := config.NewC(l) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + waitForListening(t, "127.0.0.1:"+port) + + cancel() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after ctx cancel") + } +} + +func TestStatsServer_listenerBindFailure_sameCfgReloadRetries(t *testing.T) { + // Hold the port so ListenAndServe will fail on first Start. + blocker, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := strconv.Itoa(blocker.Addr().(*net.TCPAddr).Port) + + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after bind failure") + } + // Bind failure should have dropped the cached config so a same-cfg + // SIGHUP can retry. + s.runMu.Lock() + cfgAfterFailure := s.runCfg + s.runMu.Unlock() + assert.Nil(t, cfgAfterFailure) + + // Free the port and reload with the same config; Start should fire again. 
+ require.NoError(t, blocker.Close()) + require.NoError(t, s.reload(c, false)) + + waitForListening(t, "127.0.0.1:"+port) + require.NotNil(t, currentRuntime(s)) + + s.Stop() +} + +func waitForListening(t *testing.T, addr string) { + t.Helper() + waitFor(t, func() bool { + conn, err := net.DialTimeout("tcp", addr, 200*time.Millisecond) + if err != nil { + return false + } + _ = conn.Close() + return true + }) +} + +// graphiteSink is a minimal TCP accept-and-discard server so graphite.Once +// calls in tests don't spam error logs or wedge on connection refused. +type graphiteSink struct { + ln net.Listener +} + +func newGraphiteSink(t *testing.T) *graphiteSink { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + g := &graphiteSink{ln: ln} + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + _, _ = io.Copy(io.Discard, c) + _ = c.Close() + }(conn) + } + }() + return g +} + +func (g *graphiteSink) Addr() string { return g.ln.Addr().String() } +func (g *graphiteSink) Close() { _ = g.ln.Close() } + +func freeTCPPort(t *testing.T) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := ln.Addr().(*net.TCPAddr).Port + require.NoError(t, ln.Close()) + return strconv.Itoa(port) +} diff --git a/test/logger.go b/test/logger.go index b5a717d8..faab0b69 100644 --- a/test/logger.go +++ b/test/logger.go @@ -1,29 +1,73 @@ package test import ( + "context" "io" + "log/slog" "os" + "time" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -func NewLogger() *logrus.Logger { - l := logrus.New() - +// NewLogger returns a *slog.Logger suitable for use in tests. Output goes to +// io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to +// stream output to stderr for local debugging. 
+func NewLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - return l + return slog.New(slog.DiscardHandler) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace } - - return l + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// NewLoggerWithOutput returns a *slog.Logger whose text output is captured by +// w. Timestamps are suppressed so tests can assert on exact output without +// baking the current time into expected strings. +func NewLoggerWithOutput(w io.Writer) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, nil)}) +} + +// NewLoggerWithOutputAndLevel is NewLoggerWithOutput with an explicit level +// so tests can exercise Enabled-gated paths. +func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// NewJSONLoggerWithOutput returns a *slog.Logger emitting JSON to w with +// timestamps suppressed, for tests that pin the JSON shape. +func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// stripTimeHandler zeros each record's time before delegating so slog's +// built-in handlers skip emitting the time attribute. Used to avoid +// timestamp-dependent assertions in tests without resorting to ReplaceAttr. 
+type stripTimeHandler struct { + inner slog.Handler +} + +func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool { + return h.inner.Enabled(ctx, l) +} + +func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error { + r.Time = time.Time{} + return h.inner.Handle(ctx, r) +} + +func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)} +} + +func (h *stripTimeHandler) WithGroup(name string) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithGroup(name)} } diff --git a/timeout_test.go b/timeout_test.go index db36fec7..ffeecc55 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -134,7 +134,7 @@ func TestTimerWheel_Purge(t *testing.T) { assert.True(t, tw.lastTick.After(lastTick)) // Make sure we get all 4 packets back - for i := 0; i < 4; i++ { + for i := range 4 { p, has := tw.Purge() assert.True(t, has) assert.Equal(t, fps[i], p) @@ -149,7 +149,7 @@ func TestTimerWheel_Purge(t *testing.T) { // Make sure we cached the free'd items assert.Equal(t, 4, tw.itemsCached) ci := tw.itemCache - for i := 0; i < 4; i++ { + for range 4 { assert.NotNil(t, ci) ci = ci.Next } diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..30d89dec 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -16,7 +16,7 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) + ListenOut(r EncReader) error WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool @@ -31,8 +31,8 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return +func (NoopConn) ListenOut(_ EncReader) error { + return nil } func (NoopConn) SupportsMultipleReaders() bool { return false diff --git a/udp/udp_android.go b/udp/udp_android.go index bb191954..3fc68003 100644 --- a/udp/udp_android.go +++ 
b/udp/udp_android.go @@ -9,11 +9,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 65ef31a5..c42a3c18 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -12,11 +12,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..8a4f5b18 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,12 +8,12 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) @@ -22,12 +22,12 @@ type StdConn struct { *net.UDPConn isV4 bool sysFd uintptr - l *logrus.Logger + l *slog.Logger } var _ Conn = &StdConn{} -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { @@ -173,11 +173,10 @@ func (u *StdConn) ListenOut(r EncReader) { n, 
rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } - u.l.WithError(err).Error("unexpected udp socket receive error") + u.l.Error("unexpected udp socket receive error", "error", err) } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) @@ -197,7 +196,7 @@ func (u *StdConn) Rebind() error { } if err != nil { - u.l.WithError(err).Error("Failed to rebind udp socket") + u.l.Error("Failed to rebind udp socket", "error", err) } return nil diff --git a/udp/udp_generic.go b/udp/udp_generic.go index e9dad6c5..131eb73b 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -12,22 +12,22 @@ import ( "context" "errors" "fmt" + "log/slog" "net" "net/netip" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type GenericConn struct { *net.UDPConn - l *logrus.Logger + l *slog.Logger } var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -73,7 +73,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) { +func (u *GenericConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -83,13 +83,12 @@ func (u *GenericConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - 
u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..3e2d726a 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -4,72 +4,73 @@ package udp import ( + "context" "encoding/binary" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + udpConn *net.UDPConn + rawConn syscall.RawConn + isV4 bool + l *slog.Logger + batch int } -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true - } - return ip, false -} - -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - af := unix.AF_INET6 - if ip.Is4() { - af = unix.AF_INET - } - syscall.ForkLock.RLock() - fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) - if err == nil { - unix.CloseOnExec(fd) - } - syscall.ForkLock.RUnlock() - +func setReusePort(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + //CloseOnExec already set by the runtime + }) + if err != nil { + return err + } + return opErr +} + +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { + listen := netip.AddrPortFrom(ip, uint16(port)) + lc := net.ListenConfig{} + if multi { + lc.Control = setReusePort + } + //this context is only used during the bind operation, you can't cancel it to kill the socket + pc, err := lc.ListenPacket(context.Background(), "udp", listen.String()) if err != nil { - unix.Close(fd) return nil, fmt.Errorf("unable to open socket: %s", err) } - - if multi { - if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, 
unix.SO_REUSEPORT, 1); err != nil { - return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) - } + udpConn := pc.(*net.UDPConn) + rawConn, err := udpConn.SyscallConn() + if err != nil { + _ = udpConn.Close() + return nil, err + } + //gotta find out if we got an AF_INET6 socket or not: + out := &StdConn{ + udpConn: udpConn, + rawConn: rawConn, + l: l, + batch: batch, } - var sa unix.Sockaddr - if ip.Is4() { - sa4 := &unix.SockaddrInet4{Port: port} - sa4.Addr = ip.As4() - sa = sa4 - } else { - sa6 := &unix.SockaddrInet6{Port: port} - sa6.Addr = ip.As16() - sa = sa6 - } - if err = unix.Bind(fd, sa); err != nil { - return nil, fmt.Errorf("unable to bind to socket: %s", err) + af, err := out.getSockOptInt(unix.SO_DOMAIN) + if err != nil { + _ = out.Close() + return nil, err } + out.isV4 = af == unix.AF_INET - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return out, nil } func (u *StdConn) SupportsMultipleReaders() bool { @@ -80,62 +81,133 @@ func (u *StdConn) Rebind() error { return nil } +func (u *StdConn) getSockOptInt(opt int) (int, error) { + if u.rawConn == nil { + return 0, fmt.Errorf("no UDP connection") + } + var out int + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + out, opErr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, opt) + }) + if err != nil { + return 0, err + } + return out, opErr +} + +func (u *StdConn) setSockOptInt(opt int, n int) error { + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, opt, n) + }) + if err != nil { + return err + } + return opErr +} + func (u *StdConn) SetRecvBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) + return u.setSockOptInt(unix.SO_RCVBUFFORCE, n) } func (u *StdConn) SetSendBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) + return 
u.setSockOptInt(unix.SO_SNDBUFFORCE, n) } func (u *StdConn) SetSoMark(mark int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) + return u.setSockOptInt(unix.SO_MARK, mark) } func (u *StdConn) GetRecvBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) + return u.getSockOptInt(unix.SO_RCVBUF) } func (u *StdConn) GetSendBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) + return u.getSockOptInt(unix.SO_SNDBUF) } func (u *StdConn) GetSoMark() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) + return u.getSockOptInt(unix.SO_MARK) } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { - sa, err := unix.Getsockname(u.sysFd) - if err != nil { - return netip.AddrPort{}, err - } + a := u.udpConn.LocalAddr() - switch sa := sa.(type) { - case *unix.SockaddrInet4: - return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil - - case *unix.SockaddrInet6: - return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil + switch v := a.(type) { + case *net.UDPAddr: + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } -func (u *StdConn) ListenOut(r EncReader) { +func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { + var errno syscall.Errno + n, _, errno := unix.Syscall6( + unix.SYS_RECVMMSG, + fd, + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + unix.MSG_WAITFORONE, + 0, + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + // No data available, block for I/O and try again. 
+ return int(n), false, nil + } + if errno != 0 { + return int(n), true, &net.OpError{Op: "recvmmsg", Err: errno} + } + return int(n), true, nil +} + +func (u *StdConn) listenOutSingle(r EncReader) error { + var err error + var n int + var from netip.AddrPort + buffer := make([]byte, MTU) + + for { + n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) + if err != nil { + return err + } + from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) + r(from, buffer[:n]) + } +} + +func (u *StdConn) listenOutBatch(r EncReader) error { var ip netip.Addr + var n int + var operr error msgs, buffers, names := u.PrepareRawMessages(u.batch) - read := u.ReadMulti - if u.batch == 1 { - read = u.ReadSingle + + //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read + //defining it outside the loop so it gets re-used + reader := func(fd uintptr) (done bool) { + n, done, operr = recvmmsg(fd, msgs) + return done } for { - n, err := read(msgs) + err := u.rawConn.Read(reader) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err + } + if operr != nil { + return operr } for i := 0; i < n; i++ { @@ -150,106 +222,17 @@ func (u *StdConn) ListenOut(r EncReader) { } } -func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, - 0, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmsg", Err: err} - } - - msgs[0].Len = uint32(n) - return 1, nil - } -} - -func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), - unix.MSG_WAITFORONE, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmmsg", Err: err} - } - - return int(n), nil +func (u *StdConn) ListenOut(r EncReader) error { + if u.batch == 1 
{ + return u.listenOutSingle(r) + } else { + return u.listenOutBatch(r) } } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) - } - return u.writeTo6(b, ip) -} - -func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { - var rsa unix.RawSockaddrInet6 - rsa.Family = unix.AF_INET6 - rsa.Addr = ip.Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet6), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } -} - -func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { - if !ip.Addr().Is4() { - return ErrInvalidIPv6RemoteForSocket - } - - var rsa unix.RawSockaddrInet4 - rsa.Family = unix.AF_INET - rsa.Addr = ip.Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet4), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } + _, err := u.udpConn.WriteToUDPAddrPort(b, ip) + return err } func (u *StdConn) ReloadConfig(c *config.C) { @@ -259,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetRecvBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.read_buffer was set") + u.l.Info("listen.read_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.read_buffer") + u.l.Warn("Failed to get listen.read_buffer", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.read_buffer") + u.l.Error("Failed to set listen.read_buffer", "error", err) } 
} @@ -274,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSendBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.write_buffer was set") + u.l.Info("listen.write_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.write_buffer") + u.l.Warn("Failed to get listen.write_buffer", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.write_buffer") + u.l.Error("Failed to set listen.write_buffer", "error", err) } } @@ -290,27 +273,40 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSoMark() if err == nil { - u.l.WithField("mark", s).Info("listen.so_mark was set") + u.l.Info("listen.so_mark was set", "mark", s) } else { - u.l.WithError(err).Warn("Failed to get listen.so_mark") + u.l.Warn("Failed to get listen.so_mark", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.so_mark") + u.l.Error("Failed to set listen.so_mark", "error", err) } } } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS - _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) - if err != 0 { + + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + _, _, syserr := unix.Syscall6(unix.SYS_GETSOCKOPT, fd, uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) + if syserr != 0 { + opErr = syserr + } + }) + if err != nil { return err } - return nil + return opErr } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + if u.udpConn != nil { + return u.udpConn.Close() + } + return nil } func NewUDPStatsEmitter(udpConns []Conn) func() { diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go 
index 3b69159a..4b2de75a 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -11,11 +11,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 3d60f34c..d110af19 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/netip" "sync" @@ -17,7 +18,6 @@ import ( "time" "unsafe" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" @@ -53,14 +53,14 @@ type ringBuffer struct { type RIOConn struct { isOpen atomic.Bool - l *logrus.Logger + l *slog.Logger sock windows.Handle rx, tx ringBuffer rq winrio.Rq results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { +func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } @@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro return u, nil } -func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { +func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error { var err error u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) if err != nil { @@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. 
- l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl") + l.Debug("failed to set UDP_CONNRESET ioctl", "error", err) } ret = 0 @@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. - l.WithError(err).Debug("failed to set UDP_NETRESET ioctl") + l.Debug("failed to set UDP_NETRESET ioctl", "error", err) } err = u.rx.Open() @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) { +func (u *RIOConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -151,13 +151,12 @@ func (u *RIOConn) ListenOut(r EncReader) { if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..f872e32a 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,11 +4,13 @@ package udp import ( + "context" "io" + "log/slog" "net/netip" - "sync/atomic" + "os" + "sync" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -19,32 +21,72 @@ type Packet struct { Data []byte } +// Copy returns a fresh *Packet (from the freelist) with a duplicate Data buffer. 
func (u *Packet) Copy() *Packet { - n := &Packet{ - To: u.To, - From: u.From, - Data: make([]byte, len(u.Data)), + n := acquirePacket() + n.To = u.To + n.From = u.From + if cap(n.Data) < len(u.Data) { + n.Data = make([]byte, len(u.Data)) + } else { + n.Data = n.Data[:len(u.Data)] } - copy(n.Data, u.Data) return n } +// Release returns p to the harness packet freelist. +// Callers that pull a *Packet from Get / TxPackets must Release when done. +// Channel-backed instead of sync.Pool because sync.Pool's per-P caches drain badly under cross-goroutine Get/Put, +// and putting a []byte in a Pool escapes the slice header to heap. +func (p *Packet) Release() { + if p == nil { + return + } + p.Data = p.Data[:0] + select { + case packetFreelist <- p: + default: + // Freelist full; drop the *Packet for the GC. + } +} + +// packetFreelist retains *Packet structs (and their backing Data arrays) so steady-state allocation drops to zero. +var packetFreelist = make(chan *Packet, 64) + +func acquirePacket() *Packet { + select { + case p := <-packetFreelist: + return p + default: + return &Packet{} + } +} + type TesterConn struct { Addr netip.AddrPort RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula - closed atomic.Bool - l *logrus.Logger + // done is closed exactly once by Close. Senders select on it so they + // never race with a channel close; readers exit when it fires. The + // packet channels are intentionally never closed - that was the source + // of `send on closed channel` panics when a WriteTo/Send from another + // goroutine passed the close check and reached the send just after + // Close ran. 
+ done chan struct{} + closeOnce sync.Once + + l *slog.Logger } -func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), + done: make(chan struct{}), l: l, }, nil } @@ -53,21 +95,23 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { - if u.closed.Load() { - return + if u.l.Enabled(context.Background(), slog.LevelDebug) { + // Parse the header only under debug logging, otherwise the + // allocation would show up in every Send call. + var h header.H + if err := h.Parse(packet.Data); err != nil { + panic(err) + } + u.l.Debug("UDP receiving injected packet", + "header", &h, + "udpAddr", packet.From, + "dataLen", len(packet.Data), + ) } - - h := &header.H{} - if err := h.Parse(packet.Data); err != nil { - panic(err) + select { + case <-u.done: + case u.RxPackets <- packet: } - if u.l.Level >= logrus.DebugLevel { - u.l.WithField("header", h). - WithField("udpAddr", packet.From). - WithField("dataLen", len(packet.Data)). 
- Debug("UDP receiving injected packet") - } - u.RxPackets <- packet } // Get will pull a UdpPacket from the transmit queue @@ -75,7 +119,12 @@ func (u *TesterConn) Send(packet *Packet) { // packets were ingested from the tun side (in most cases), you can send them with Tun.Send func (u *TesterConn) Get(block bool) *Packet { if block { - return <-u.TxPackets + select { + case <-u.done: + return nil + case p := <-u.TxPackets: + return p + } } select { @@ -91,28 +140,33 @@ func (u *TesterConn) Get(block bool) *Packet { //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { - if u.closed.Load() { - return io.ErrClosedPipe + p := acquirePacket() + if cap(p.Data) < len(b) { + p.Data = make([]byte, len(b)) + } else { + p.Data = p.Data[:len(b)] } - - p := &Packet{ - Data: make([]byte, len(b), len(b)), - From: u.Addr, - To: addr, - } - copy(p.Data, b) - u.TxPackets <- p - return nil + p.From = u.Addr + p.To = addr + select { + case <-u.done: + p.Release() + return io.ErrClosedPipe + case u.TxPackets <- p: + return nil + } } -func (u *TesterConn) ListenOut(r EncReader) { +func (u *TesterConn) ListenOut(r EncReader) error { for { - p, ok := <-u.RxPackets - if !ok { - return + select { + case <-u.done: + return os.ErrClosed + case p := <-u.RxPackets: + r(p.From, p.Data) + p.Release() } - r(p.From, p.Data) } } @@ -136,9 +190,8 @@ func (u *TesterConn) Rebind() error { } func (u *TesterConn) Close() error { - if u.closed.CompareAndSwap(false, true) { - close(u.RxPackets) - close(u.TxPackets) - } + u.closeOnce.Do(func() { + close(u.done) + }) return nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 1b777c37..7969f7e8 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -5,14 +5,13 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - - "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip 
netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between @@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return rc, nil } - l.WithError(err).Error("Falling back to standard udp sockets") + l.Error("Falling back to standard udp sockets", "error", err) return NewGenericListener(l, ip, port, multi, batch) } diff --git a/util/error.go b/util/error.go index 814c77a1..14371d3f 100644 --- a/util/error.go +++ b/util/error.go @@ -1,10 +1,10 @@ package util import ( + "context" "errors" "fmt" - - "github.com/sirupsen/logrus" + "log/slog" ) type ContextualError struct { @@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error { } // LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError -func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { +func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) { switch v := err.(type) { case *ContextualError: v.Log(l) default: - l.WithError(err).Error(msg) + l.Error(msg, "error", err) } } @@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error { return ce.RealError } -func (ce *ContextualError) Log(lr *logrus.Logger) { - if ce.RealError != nil { - lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) - } else { - lr.WithFields(ce.Fields).Error(ce.Context) +// Log emits ce as a single error-level log line with Fields and RealError +// promoted to top-level attributes, producing a flat shape callers can grep +// or parse without walking into a nested object. 
+func (ce *ContextualError) Log(l *slog.Logger) { + attrs := make([]slog.Attr, 0, len(ce.Fields)+1) + for k, v := range ce.Fields { + attrs = append(attrs, slog.Any(k, v)) } + if ce.RealError != nil { + attrs = append(attrs, slog.Any("error", ce.RealError)) + } + // LogAttrs is intentional: attrs is built from a map[string]any so it has + // no pair-form equivalent. + //nolint:sloglint + l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...) } diff --git a/util/error_test.go b/util/error_test.go index 692c1840..30e39e33 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -1,95 +1,67 @@ package util import ( + "bytes" "errors" "fmt" "testing" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) type m = map[string]any -type TestLogWriter struct { - Logs []string -} - -func NewTestLogWriter() *TestLogWriter { - return &TestLogWriter{Logs: make([]string, 0)} -} - -func (tl *TestLogWriter) Write(p []byte) (n int, err error) { - tl.Logs = append(tl.Logs, string(p)) - return len(p), nil -} - -func (tl *TestLogWriter) Reset() { - tl.Logs = tl.Logs[:0] -} - func TestContextualError_Log(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test a full context line - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test a line with an error and msg but no fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) + assert.Equal(t, 
"level=ERROR msg=\"test message\" error=error\n", buf.String()) // Test just a context and fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", m{"field": "1"}, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String()) // Test just a context - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String()) // Test just an error - tl.Reset() + buf.Reset() e = NewContextualError("", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String()) } func TestLogWithContextIfNeeded(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test ignoring fallback context - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) LogWithContextIfNeeded("This should get thrown away", e, l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test using fallback context - tl.Reset() + buf.Reset() err := fmt.Errorf("this is a normal error") LogWithContextIfNeeded("Fallback context woo", err, l) - assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String()) } func TestContextualizeIfNeeded(t *testing.T) {