diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9ce1d5e3..a5e8d397 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -209,10 +209,11 @@ jobs: id: create_release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REF_NAME: ${{ github.ref_name }} run: | cd artifacts gh release create \ --verify-tag \ - --title "Release ${{ github.ref_name }}" \ - "${{ github.ref_name }}" \ + --title "Release ${GITHUB_REF_NAME}" \ + "${GITHUB_REF_NAME}" \ SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 24f899ab..3734db75 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -18,6 +18,8 @@ jobs: if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') name: Run extra smoke tests runs-on: ubuntu-latest + env: + VAGRANT_DEFAULT_PROVIDER: libvirt steps: - uses: actions/checkout@v6 @@ -30,8 +32,13 @@ jobs: - name: add hashicorp source run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list - - name: install vagrant - run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox + - name: install vagrant and libvirt + run: | + sudo apt-get update && sudo apt-get install -y vagrant libvirt-daemon-system libvirt-dev + sudo chmod 666 /dev/kvm + sudo usermod -aG libvirt $(whoami) + sudo chmod 666 /var/run/libvirt/libvirt-sock + vagrant plugin install vagrant-libvirt - name: freebsd-amd64 run: make smoke-vagrant/freebsd-amd64 @@ -42,10 +49,19 @@ jobs: - name: netbsd-amd64 run: make smoke-vagrant/netbsd-amd64 - - name: linux-386 - run: make smoke-vagrant/linux-386 - - name: 
linux-amd64-ipv6disable run: make smoke-vagrant/linux-amd64-ipv6disable + # linux-386 runs last because it requires disabling KVM to use VirtualBox, + # which prevents libvirt (used by the other tests) from working after this point. + - name: install virtualbox for i386 test + run: | + sudo apt-get install -y virtualbox + sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true + + - name: linux-386 + env: + VAGRANT_DEFAULT_PROVIDER: virtualbox + run: make smoke-vagrant/linux-386 + timeout-minutes: 30 diff --git a/.github/workflows/smoke/build-relay.sh b/.github/workflows/smoke/build-relay.sh index 70b07f4e..249e6c84 100755 --- a/.github/workflows/smoke/build-relay.sh +++ b/.github/workflows/smoke/build-relay.sh @@ -16,8 +16,10 @@ relay: am_relay: true EOF - export LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" - export REMOTE_ALLOW_LIST='{"172.17.0.4/32": false, "172.17.0.5/32": false}' + # TEST-NET-3 placeholder IPs; smoke-relay.sh seds them to real container IPs. + # Mapping: .2 lighthouse1, .3 host2, .4 host3, .5 host4. + export LIGHTHOUSES="192.168.100.1 203.0.113.2:4242" + export REMOTE_ALLOW_LIST='{"203.0.113.4/32": false, "203.0.113.5/32": false}' HOST="host2" ../genconfig.sh >host2.yml <host3.yml diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index dcd132b0..b23516ee 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -5,9 +5,15 @@ set -e -x rm -rf ./build mkdir ./build -# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1 -# - We could make this better by launching the lighthouse first and then fetching what IP it is. -NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)" +# Smoke containers run on a dedicated docker network whose subnet is allocated +# at smoke time, not known at build time. 
Configs are written with TEST-NET-3 +# placeholder IPs (RFC 5737) and smoke.sh / smoke-vagrant.sh / smoke-relay.sh +# sed the real container IPs in before starting nebula. +# +# Placeholder mapping (last octet == fixed container slot): +# 203.0.113.2 -> lighthouse1, 203.0.113.3 -> host2, +# 203.0.113.4 -> host3, 203.0.113.5 -> host4. +LIGHTHOUSE_IP="203.0.113.2" ( cd build @@ -25,16 +31,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ ../genconfig.sh >lighthouse1.yml HOST="host2" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ ../genconfig.sh >host2.yml HOST="host3" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host3.yml HOST="host4" \ - LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ + LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml diff --git a/.github/workflows/smoke/smoke-relay.sh b/.github/workflows/smoke/smoke-relay.sh index 9c113e18..aa1cd915 100755 --- a/.github/workflows/smoke/smoke-relay.sh +++ b/.github/workflows/smoke/smoke-relay.sh @@ -6,6 +6,8 @@ set -o pipefail mkdir -p logs +NETWORK="nebula-smoke-relay" + cleanup() { echo echo " *** cleanup" @@ -16,22 +18,53 @@ cleanup() { then docker kill lighthouse1 host2 host3 host4 fi + docker network rm "$NETWORK" >/dev/null 2>&1 } trap cleanup EXIT -docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test -docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test -docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test -docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). 
Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2, .4 host3, .5 host4 — matches the placeholders in build-relay.sh. +SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" +HOST3_IP="$PREFIX.4" +HOST4_IP="$PREFIX.5" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. 
+for f in build/host2.yml build/host3.yml build/host4.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + +docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" nebula:smoke-relay -config host2.yml -test +docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" nebula:smoke-relay -config host3.yml -test +docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" nebula:smoke-relay -config host4.yml -test + +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device 
/dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 set +x @@ -76,7 +109,13 @@ docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 5 + +# Wait up to 30s for all backgrounded jobs to exit rather than relying on a +# fixed sleep. +for _ in $(seq 1 30); do + [ -z "$(jobs -r)" ] && break + sleep 1 +done if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 1c1e3c50..e3863cb5 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -8,6 +8,8 @@ export VAGRANT_CWD="$PWD/vagrant-$1" mkdir -p logs +NETWORK="nebula-smoke" + cleanup() { echo echo " *** cleanup" @@ -19,21 +21,51 @@ cleanup() { docker kill lighthouse1 host2 fi vagrant destroy -f + docker network rm "$NETWORK" >/dev/null 2>&1 } trap cleanup EXIT +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi + +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2 — matches the placeholders in build.sh. 
+SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. +# This must happen before `vagrant up` rsyncs build/ into the VM for host3. +for f in build/host2.yml build/host3.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test -docker run --name host2 --rm "$CONTAINER" -config host2.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test vagrant up vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 15 @@ -96,7 +128,14 @@ vagrant ssh -c "ping -c1 192.168.100.2" -- -T vagrant ssh -c "sudo xargs kill /dev/null 2>&1 } trap 
cleanup EXIT +# Create a dedicated smoke network with an explicit subnet (required for --ip +# below). Probe a short list of candidates so a locally-used range doesn't +# fail the whole test — we only need one to be free. +docker network rm "$NETWORK" >/dev/null 2>&1 || true +for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do + if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then + break + fi +done +if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then + echo "failed to create $NETWORK: every candidate subnet is in use" >&2 + exit 1 +fi + +# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1, +# .3 host2, .4 host3, .5 host4 — matches the placeholders in build.sh. +SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")" +PREFIX="${SUBNET%/*}" +PREFIX="${PREFIX%.*}" +LIGHTHOUSE_IP="$PREFIX.2" +HOST2_IP="$PREFIX.3" +HOST3_IP="$PREFIX.4" +HOST4_IP="$PREFIX.5" + +# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones. +# build/lighthouse1.yml has no IPs to rewrite so it's skipped. 
+for f in build/host2.yml build/host3.yml build/host4.yml; do + sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp" + mv "$f.tmp" "$f" +done + CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test -docker run --name host2 --rm "$CONTAINER" -config host2.yml -test -docker run --name host3 --rm "$CONTAINER" -config host3.yml -test -docker run --name host4 --rm "$CONTAINER" -config host4.yml -test +docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test +docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" "$CONTAINER" -config host3.yml -test +docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" "$CONTAINER" -config host4.yml -test -docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" 
-config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 # grab tcpdump pcaps for debugging -docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & +docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & -docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & +docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & -docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & +docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & -docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & +docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & docker exec host2 ncat -nklv 0.0.0.0 2000 & docker exec host3 ncat -nklv 0.0.0.0 2000 & +docker exec host4 ncat -e '/usr/bin/echo helloagainfromhost4' -nkluv 0.0.0.0 4000 & docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 
& docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & @@ -119,17 +154,24 @@ echo echo " *** Testing conntrack" echo set -x -# host2 can ping host3 now that host3 pinged it first -docker exec host2 ping -c1 192.168.100.3 -# host4 can ping host2 once conntrack established -docker exec host2 ping -c1 192.168.100.4 -docker exec host4 ping -c1 192.168.100.2 + +# host4's outbound firewall only allows ICMP to the lighthouse, so host4 +# cannot initiate UDP to host2. Once host2 initiates a flow to host4:4000, +# conntrack must let host4's listener reply on that flow. If it doesn't, +# the echo back from host4 never reaches host2. +docker exec host2 sh -c "(/usr/bin/echo host2; sleep 2) | ncat -nuv 192.168.100.4 4000" | grep -q helloagainfromhost4 docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 5 + +# Wait up to 30s for all backgrounded jobs to exit rather than relying on a +# fixed sleep. 
+for _ in $(seq 1 30); do + [ -z "$(jobs -r)" ] && break + sleep 1 +done if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile index 89f94772..eeb9679e 100644 --- a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile +++ b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "ubuntu/jammy64" + config.vm.box = "bento/ubuntu-24.04" config.vm.synced_folder "../build", "/nebula" diff --git a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile index e4f41049..6dd26373 100644 --- a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile +++ b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "generic/openbsd7" + config.vm.box = "DefinedNet/openbsd78" config.vm.synced_folder "../build", "/nebula", type: "rsync" end diff --git a/.golangci.yaml b/.golangci.yaml index bd82a952..be0513d4 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -2,7 +2,21 @@ version: "2" linters: default: none enable: + - sloglint - testifylint + settings: + sloglint: + # Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and + # the package-level slog equivalents. Use l.Log(ctx, level, ...) for + # custom levels instead of LogAttrs when you can. + # + # LogAttrs is also flagged by this rule because it takes ...slog.Attr; + # the few legitimate sites (where attrs is built up as a []slog.Attr) + # carry a //nolint:sloglint with rationale. + kv-only: true + # no-mixed-args is on by default: forbids mixing kv and attrs in one call. 
+ # discard-handler is on by default (since Go 1.24): suggests + # slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil). exclusions: generated: lax presets: diff --git a/CHANGELOG.md b/CHANGELOG.md index 330a7f78..2ef7551f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.10.3] - 2026-02-06 + +### Security + +- Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations. + Both fingerprint representations will be tested against the blocklist. + Any newly issued P256 based certificates will have their signature clamped to the low-s form. + Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962) + +### Changed + +- Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588) + ## [1.10.2] - 2026-01-21 ### Fixed @@ -775,7 +788,8 @@ created.) - Initial public release. 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.2...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD +[1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3 [1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2 [1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1 [1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0 diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..00cd7bd1 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +#ECCN:Open Source diff --git a/README.md b/README.md index 4dbc0956..f27e0b5b 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for docker pull nebulaoss/nebula ``` -#### Mobile +#### Mobile ([source code](https://github.com/DefinedNet/mobile_nebula)) - [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200) - [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1) @@ -76,6 +76,8 @@ Nebula was created to provide a mechanism for groups of hosts to communicate sec ## Getting started (quickly) +**Don't want to manage your own PKI and lighthouses?** [Managed Nebula](https://www.defined.net/) from Defined Networking handles all of this for you. + To set up a Nebula network, you'll need: #### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use. 
diff --git a/bits.go b/bits.go index af11cc48..5c8f902b 100644 --- a/bits.go +++ b/bits.go @@ -1,8 +1,10 @@ package nebula import ( + "context" + "log/slog" + "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" ) type Bits struct { @@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits { return b } -func (b *Bits) Check(l *logrus.Logger, i uint64) bool { +func (b *Bits) Check(l *slog.Logger, i uint64) bool { // If i is the next number, return true. if i > b.current { return true @@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool { } // Not within the window - if l.Level >= logrus.DebugLevel { - l.Debugf("rejected a packet (top) %d %d\n", b.current, i) + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("rejected a packet (top)", + "current", b.current, + "incoming", i, + ) } return false } -func (b *Bits) Update(l *logrus.Logger, i uint64) bool { +func (b *Bits) Update(l *slog.Logger, i uint64) bool { // If i is the next number, return true and update current. if i == b.current+1 { // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter @@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // Check to see if it's a duplicate if i > b.current-b.length || i < b.length && b.current < b.length { if b.current == i || b.bits[i%b.length] == true { - if l.Level >= logrus.DebugLevel { - l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). - Debug("Receive window") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "duplicate", + ) } b.dupeCounter.Inc(1) return false @@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // In all other cases, fail and don't change current. 
b.outOfWindowCounter.Inc(1) - if l.Level >= logrus.DebugLevel { - l.WithField("accepted", false). - WithField("currentCounter", b.current). - WithField("incomingCounter", i). - WithField("reason", "nonsense"). - Debug("Receive window") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "nonsense", + ) } return false } diff --git a/boring.go b/boring.go index 9cd9d37f..abe403fc 100644 --- a/boring.go +++ b/boring.go @@ -1,5 +1,4 @@ //go:build boringcrypto -// +build boringcrypto package nebula diff --git a/cert/ca_pool.go b/cert/ca_pool.go index 2bf480f2..792f8e66 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -1,11 +1,14 @@ package cert import ( + "bufio" + "bytes" + "encoding/pem" "errors" "fmt" + "io" "net/netip" "slices" - "strings" "time" ) @@ -29,22 +32,46 @@ func NewCAPool() *CAPool { // If the pool contains any expired certificates, an ErrExpired will be // returned along with the pool. The caller must handle any such errors. func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs)) +} + +// NewCAPoolFromPEMReader will create a new CA pool from the provided reader. +// The reader must contain a PEM-encoded set of nebula certificates. 
+func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) { pool := NewCAPool() - var err error + var expired bool - for { - caPEMs, err = pool.AddCAFromPEM(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil + + scanner := bufio.NewScanner(r) + scanner.Split(SplitPEM) + + for scanner.Scan() { + pemBytes := scanner.Bytes() + + block, rest := pem.Decode(pemBytes) + if len(bytes.TrimSpace(rest)) > 0 { + return nil, ErrInvalidPEMBlock } + if block == nil { + return nil, ErrInvalidPEMBlock + } + + c, err := unmarshalCertificateBlock(block) if err != nil { return nil, err } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break + + err = pool.AddCA(c) + if errors.Is(err, ErrExpired) { + expired = true + continue + } else if err != nil { + return nil, err } } + if err := scanner.Err(); err != nil { + return nil, ErrInvalidPEMBlock + } if expired { return pool, ErrExpired @@ -141,10 +168,23 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti return nil, err } + // Pre nebula v1.10.3 could generate signatures in either high or low s form and validation + // of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form + // but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we + // need to test both forms until such a time comes that we enforce low-s form on signature validation. 
+ fp2, err := CalculateAlternateFingerprint(c) + if err != nil { + return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err) + } + if fp2 != "" && ncp.IsBlocklisted(fp2) { + return nil, ErrBlockListed + } + cc := CachedCertificate{ Certificate: c, InvertedGroups: make(map[string]struct{}), Fingerprint: fp, + fingerprint2: fp2, signerFingerprint: signer.Fingerprint, } @@ -158,6 +198,11 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti // VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and // is a cheaper operation to perform as a result. func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { + // Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s + if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) { + return ErrBlockListed + } + _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint) return err } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index b0fdd5fb..ab173228 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -1,10 +1,14 @@ package cert import ( + "bytes" + "io" "net/netip" + "strings" "testing" "time" + "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -111,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe assert.Len(t, ppppp.CAs, 1) } +// oneByteReader wraps a reader to return at most 1 byte per Read call, +// exercising the streaming accumulation logic in NewCAPoolFromPEMReader. 
+type oneByteReader struct { + r io.Reader +} + +func (o *oneByteReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + return o.r.Read(p[:1]) +} + +func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) { + pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil)) + require.NoError(t, err) + assert.Empty(t, pool.CAs) + + pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n ")) + require.NoError(t, err) + assert.Empty(t, pool.CAs) +} + +func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) { + ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, pem2...) + pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)}) + require.NoError(t, err) + assert.Len(t, pool.CAs, 2) + + fp1, err := ca1.Fingerprint() + require.NoError(t, err) + fp2, err := ca2.Fingerprint() + require.NoError(t, err) + + assert.Contains(t, pool.CAs, fp1) + assert.Contains(t, pool.CAs, fp2) +} + +func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) { + _, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata")) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + +func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) { + _, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, []byte("some trailing garbage")...) 
+ _, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle)) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + func TestCertificateV1_Verify(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) @@ -170,6 +228,15 @@ func TestCertificateV1_VerifyP256(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") + // Create a copy of the cert and swap to the alternate form for the signature + nc := c.Copy() + b, err := p256.Swap(c.Signature()) + require.NoError(t, err) + require.NoError(t, nc.(*certificateV1).setSignature(b)) + + _, err = caPool.VerifyCertificate(time.Now(), nc) + require.EqualError(t, err, "certificate is in the block list") + caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) @@ -187,7 +254,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { require.NoError(t, err) caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) + b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) @@ -196,7 +263,17 @@ func TestCertificateV1_VerifyP256(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - _, err = caPool.VerifyCertificate(time.Now(), c) + cc, err := caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Reset the blocklist and block the alternate form fingerprint + caPool.ResetCertBlocklist() + caPool.BlocklistFingerprint(cc.fingerprint2) + err = caPool.VerifyCachedCertificate(time.Now(), cc) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } @@ -394,6 
+471,15 @@ func TestCertificateV2_VerifyP256(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") + // Create a copy of the cert and swap to the alternate form for the signature + nc := c.Copy() + b, err := p256.Swap(c.Signature()) + require.NoError(t, err) + require.NoError(t, nc.(*certificateV2).setSignature(b)) + + _, err = caPool.VerifyCertificate(time.Now(), nc) + require.EqualError(t, err, "certificate is in the block list") + caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) @@ -411,7 +497,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { require.NoError(t, err) caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) + b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) @@ -420,7 +506,17 @@ func TestCertificateV2_VerifyP256(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - _, err = caPool.VerifyCertificate(time.Now(), c) + cc, err := caPool.VerifyCertificate(time.Now(), c) + require.NoError(t, err) + + // Reset the blocklist and block the alternate form fingerprint + caPool.ResetCertBlocklist() + caPool.BlocklistFingerprint(cc.fingerprint2) + err = caPool.VerifyCachedCertificate(time.Now(), cc) + require.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } diff --git a/cert/cert.go b/cert/cert.go index 9d40e625..01d775e5 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -4,6 +4,8 @@ import ( "fmt" "net/netip" "time" + + "github.com/slackhq/nebula/cert/p256" ) type Version uint8 @@ -110,6 +112,9 @@ type CachedCertificate struct { InvertedGroups map[string]struct{} Fingerprint string signerFingerprint string + + // A place to store a 2nd fingerprint if the certificate could have one, 
such as with P256 + fingerprint2 string } func (cc *CachedCertificate) String() string { @@ -152,3 +157,31 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific return c, nil } + +// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates +// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step. +func CalculateAlternateFingerprint(c Certificate) (string, error) { + if c.Curve() != Curve_P256 { + return "", nil + } + + nc := c.Copy() + b, err := p256.Swap(nc.Signature()) + if err != nil { + return "", err + } + + switch v := nc.(type) { + case *certificateV1: + err = v.setSignature(b) + case *certificateV2: + err = v.setSignature(b) + default: + return "", ErrUnknownVersion + } + + if err != nil { + return "", err + } + return nc.Fingerprint() +} diff --git a/cert/p256/p256.go b/cert/p256/p256.go new file mode 100644 index 00000000..dc609a35 --- /dev/null +++ b/cert/p256/p256.go @@ -0,0 +1,127 @@ +package p256 + +import ( + "crypto/elliptic" + "errors" + "math/big" + + "filippo.io/bigmod" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1) +var nMod *bigmod.Modulus + +func init() { + n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes()) + if err != nil { + panic(err) + } + nMod = n +} + +func IsNormalized(sig []byte) (bool, error) { + r, s, err := parseSignature(sig) + if err != nil { + return false, err + } + return checkLowS(r, s), nil +} + +func checkLowS(_, s []byte) bool { + bigS := new(big.Int).SetBytes(s) + // Check if S <= (N/2), because we want to include the midpoint in the set of low-s + return bigS.Cmp(halfN) <= 0 +} + +func swap(r, s []byte) ([]byte, []byte, error) { + var err error + bigS, err := bigmod.NewNat().SetBytes(s, nMod) + if err != nil { + return nil, nil, err + } + sNormalized := nMod.Nat().Sub(bigS, nMod) + + 
result := sNormalized.Bytes(nMod) + for len(result) > 1 && result[0] == 0 { + result = result[1:] + } + + return r, result, nil +} + +func Normalize(sig []byte) ([]byte, error) { + r, s, err := parseSignature(sig) + if err != nil { + return nil, err + } + + if checkLowS(r, s) { + return sig, nil + } + + newR, newS, err := swap(r, s) + if err != nil { + return nil, err + } + + return encodeSignature(newR, newS) +} + +// Swap will change sig between its current form to the opposite high or low form. +func Swap(sig []byte) ([]byte, error) { + r, s, err := parseSignature(sig) + if err != nil { + return nil, err + } + + newR, newS, err := swap(r, s) + if err != nil { + return nil, err + } + + return encodeSignature(newR, newS) +} + +// parseSignature taken exactly from crypto/ecdsa/ecdsa.go +func parseSignature(sig []byte) (r, s []byte, err error) { + var inner cryptobyte.String + input := cryptobyte.String(sig) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(&r) || + !inner.ReadASN1Integer(&s) || + !inner.Empty() { + return nil, nil, errors.New("invalid ASN.1") + } + return r, s, nil +} + +func encodeSignature(r, s []byte) ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + addASN1IntBytes(b, r) + addASN1IntBytes(b, s) + }) + return b.Bytes() +} + +// addASN1IntBytes encodes in ASN.1 a positive integer represented as +// a big-endian byte slice with zero or more leading zeroes. 
+func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { + for len(bytes) > 0 && bytes[0] == 0 { + bytes = bytes[1:] + } + if len(bytes) == 0 { + b.SetError(errors.New("invalid integer")) + return + } + b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) { + if bytes[0]&0x80 != 0 { + c.AddUint8(0) + } + c.AddBytes(bytes) + }) +} diff --git a/cert/p256/p256_test.go b/cert/p256/p256_test.go new file mode 100644 index 00000000..486a7242 --- /dev/null +++ b/cert/p256/p256_test.go @@ -0,0 +1,28 @@ +package p256 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFlipping(t *testing.T) { + priv, err1 := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err1) + + out, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus")) + require.NoError(t, err) + + r, s, err := parseSignature(out) + require.NoError(t, err) + + r, s1, err := swap(r, s) + require.NoError(t, err) + r, s2, err := swap(r, s1) + require.NoError(t, err) + require.Equal(t, s, s2) + require.NotEqual(t, s, s1) +} diff --git a/cert/pem.go b/cert/pem.go index 8942c23a..84221b22 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -1,12 +1,66 @@ package cert import ( + "bytes" "encoding/pem" + "errors" "fmt" "golang.org/x/crypto/ed25519" ) +var ErrTruncatedPEMBlock = errors.New("truncated PEM block") + +// SplitPEM is a split function for bufio.Scanner that returns each PEM block. 
+func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Look for the start of a PEM block + start := bytes.Index(data, []byte("-----BEGIN ")) + if start == -1 { + if atEOF && len(bytes.TrimSpace(data)) > 0 { + // Non-whitespace content with no PEM block + return 0, nil, ErrTruncatedPEMBlock + } + if atEOF { + return len(data), nil, nil + } + // Request more data + return 0, nil, nil + } + + // Look for the end marker + endMarkerStart := bytes.Index(data[start:], []byte("-----END ")) + if endMarkerStart == -1 { + if atEOF { + // Incomplete PEM block at EOF + return 0, nil, ErrTruncatedPEMBlock + } + // Need more data to find the end + return 0, nil, nil + } + + // Find the actual end of the END line (after the newline) + endMarkerStart += start + endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n') + var end int + if endLineEnd == -1 { + if atEOF { + // END marker without newline at EOF - take it anyway + end = len(data) + } else { + // Need more data + return 0, nil, nil + } + } else { + end = endMarkerStart + endLineEnd + 1 + } + + // Extract the PEM block + pemBlock := data[start:end] + + // Return the valid PEM block + return end, pemBlock, nil +} + const ( //cert banners CertificateBanner = "NEBULA CERTIFICATE" CertificateV2Banner = "NEBULA CERTIFICATE V2" @@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } - var c Certificate - var err error - - switch p.Type { - // Implementations must validate the resulting certificate contains valid information - case CertificateBanner: - c, err = unmarshalCertificateV1(p.Bytes, nil) - case CertificateV2Banner: - c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) - default: - return nil, r, ErrInvalidPEMCertificateBanner - } - + c, err := unmarshalCertificateBlock(p) if err != nil { return nil, r, err } @@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { } 
+// unmarshalCertificateBlock decodes a single PEM block into a certificate. +// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise. +func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) { + switch block.Type { + // Implementations must validate the resulting certificate contains valid information + case CertificateBanner: + return unmarshalCertificateV1(block.Bytes, nil) + case CertificateV2Banner: + return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519) + default: + return nil, ErrInvalidPEMCertificateBanner + } +} + func marshalCertPublicKeyToPEM(c Certificate) []byte { if c.IsCA() { return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) diff --git a/cert/pem_test.go b/cert/pem_test.go index 310c57a3..ff623541 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -1,12 +1,88 @@ package cert import ( + "bufio" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func scanAll(t *testing.T, input string) ([]string, error) { + t.Helper() + scanner := bufio.NewScanner(strings.NewReader(input)) + scanner.Split(SplitPEM) + var blocks []string + for scanner.Scan() { + blocks = append(blocks, scanner.Text()) + } + return blocks, scanner.Err() +} + +func TestSplitPEM_Single(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_Multiple(t *testing.T) { + block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n" + block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, block1+block2) + require.NoError(t, err) + require.Len(t, blocks, 2) + require.Equal(t, block1, blocks[0]) + require.Equal(t, block2, blocks[1]) +} + +func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) { + input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END 
TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 2) +} + +func TestSplitPEM_Empty(t *testing.T) { + blocks, err := scanAll(t, "") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_WhitespaceOnly(t *testing.T) { + blocks, err := scanAll(t, " \n\t\n ") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_TrailingGarbage(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage" + blocks, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) + require.Len(t, blocks, 1) +} + +func TestSplitPEM_TruncatedBlock(t *testing.T) { + input := "-----BEGIN TEST-----\npartial data with no end" + _, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + +func TestSplitPEM_NoEndNewline(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_GarbageOnly(t *testing.T) { + _, err := scanAll(t, "this is not PEM data") + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + func TestUnmarshalCertificateFromPEM(t *testing.T) { goodCert := []byte(` # A good cert diff --git a/cert/sign.go b/cert/sign.go index 3eb08592..fbfffe4e 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -9,6 +9,8 @@ import ( "fmt" "net/netip" "time" + + "github.com/slackhq/nebula/cert/p256" ) // TBSCertificate represents a certificate intended to be signed. 
@@ -126,6 +128,13 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb return nil, err } + if curve == Curve_P256 { + sig, err = p256.Normalize(sig) + if err != nil { + return nil, err + } + } + err = c.setSignature(sig) if err != nil { return nil, err diff --git a/cert/sign_test.go b/cert/sign_test.go index e6f43cdf..bf4c9c0d 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -89,3 +90,48 @@ func TestCertificateV1_SignP256(t *testing.T) { require.NoError(t, err) assert.NotNil(t, uc) } + +func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: true, + Curve: Curve_P256, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + for i := 0; i < 1000; i++ { + if i&1 == 1 { + tbs.Version = Version1 + } else { + tbs.Version = Version2 + } + c, err := tbs.Sign(nil, Curve_P256, rawPriv) + require.NoError(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + normie, err := p256.IsNormalized(c.Signature()) + require.NoError(t, err) + assert.True(t, 
normie) + } +} diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index bea4d1d9..36258dd8 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "os" - "strings" "time" "github.com/slackhq/nebula/cert" @@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCACert, err := os.ReadFile(*vf.caPath) + caFile, err := os.Open(*vf.caPath) if err != nil { return fmt.Errorf("error while reading ca: %w", err) } + defer caFile.Close() - caPool := cert.NewCAPool() - for { - rawCACert, err = caPool.AddCAFromPEM(rawCACert) - if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %w", err) - } - - if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { - break - } + caPool, err := cert.NewCAPoolFromPEMReader(caFile) + if err != nil && !errors.Is(err, cert.ErrExpired) { + return fmt.Errorf("error while adding ca cert to pool: %w", err) } rawCert, err := os.ReadFile(*vf.certPath) diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f555e5f5..1aa5e8e6 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -64,7 +64,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) - require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + require.ErrorIs(t, err, cert.ErrInvalidPEMBlock) // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) diff --git a/cmd/nebula-service/logs_generic.go b/cmd/nebula-service/logs_generic.go index 3b7cdd1c..cc06b4c5 100644 --- a/cmd/nebula-service/logs_generic.go +++ b/cmd/nebula-service/logs_generic.go @@ -3,8 +3,15 @@ package main -import "github.com/sirupsen/logrus" +import ( + "log/slog" + "os" -func HookLogger(l *logrus.Logger) { - // 
Do nothing, let the logs flow to stdout/stderr + "github.com/slackhq/nebula/logging" +) + +// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows +// platforms have no special sink to integrate with. +func newPlatformLogger() *slog.Logger { + return logging.NewLogger(os.Stdout) } diff --git a/cmd/nebula-service/logs_windows.go b/cmd/nebula-service/logs_windows.go index af6480ef..ca0a55c5 100644 --- a/cmd/nebula-service/logs_windows.go +++ b/cmd/nebula-service/logs_windows.go @@ -1,54 +1,86 @@ package main import ( - "fmt" - "io/ioutil" - "os" + "context" + "log/slog" + "strings" + "sync" - "github.com/kardianos/service" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer -// logrus output will be discarded -func HookLogger(l *logrus.Logger) { - l.AddHook(newLogHook(logger)) - l.SetOutput(ioutil.Discard) +// newPlatformLogger returns a *slog.Logger that routes every log record +// through the Windows service logger so records end up in the Windows +// Event Log. All the heavy lifting (level management, format swap, +// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler; +// this file only contributes: +// +// - an io.Writer that forwards each formatted line to the service +// logger at the current record's Event Log severity, and +// - a thin severityTag that embeds *logging.Handler and overrides +// only Handle / WithAttrs / WithGroup, so Event Viewer's severity +// column and severity-based filters keep working the way they did +// before the slog migration. +// +// Format (text vs json) is carried by the embedded *logging.Handler, so +// logging.format: json in config still produces JSON lines in Event +// Viewer, same as the pre-slog logrus setup. 
+func newPlatformLogger() *slog.Logger { + w := &eventLogWriter{} + return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w}) } -type logHook struct { - sl service.Logger +// eventLogWriter forwards slog-formatted lines to the Windows service +// logger at the severity most recently stashed by severityTag.Handle. +// The mutex serializes the stash + inner.Handle + Write cycle per record +// across all concurrent goroutines; slog's builtin text/json handlers +// each hold their own mutex around Write, but that only protects the +// Write call itself, not our stash-then-handle sequence. +type eventLogWriter struct { + mu sync.Mutex + level slog.Level } -func newLogHook(sl service.Logger) *logHook { - return &logHook{sl: sl} -} - -func (h *logHook) Fire(entry *logrus.Entry) error { - line, err := entry.String() - if err != nil { - fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err) - return err - } - - switch entry.Level { - case logrus.PanicLevel: - return h.sl.Error(line) - case logrus.FatalLevel: - return h.sl.Error(line) - case logrus.ErrorLevel: - return h.sl.Error(line) - case logrus.WarnLevel: - return h.sl.Warning(line) - case logrus.InfoLevel: - return h.sl.Info(line) - case logrus.DebugLevel: - return h.sl.Info(line) +func (w *eventLogWriter) Write(p []byte) (int, error) { + line := strings.TrimRight(string(p), "\n") + switch { + case w.level >= slog.LevelError: + return len(p), logger.Error(line) + case w.level >= slog.LevelWarn: + return len(p), logger.Warning(line) default: - return nil + return len(p), logger.Info(line) } } -func (h *logHook) Levels() []logrus.Level { - return logrus.AllLevels +// severityTag embeds *logging.Handler to pick up everything it does for +// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat, +// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup +// so each record's slog.Level is stashed on the writer before formatting +// and so derived handlers stay wrapped as severityTag rather than 
+// downgrading to bare *logging.Handler. +type severityTag struct { + *logging.Handler + w *eventLogWriter +} + +func (s *severityTag) Handle(ctx context.Context, r slog.Record) error { + s.w.mu.Lock() + defer s.w.mu.Unlock() + s.w.level = r.Level + return s.Handler.Handle(ctx, r) +} + +func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return s + } + return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w} +} + +func (s *severityTag) WithGroup(name string) slog.Handler { + if name == "" { + return s + } + return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w} } diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 9a17b947..19fb3a9f 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -50,9 +50,14 @@ func main() { os.Exit(0) } + l := logging.NewLogger(os.Stdout) + if *serviceFlag != "" { - doService(configPath, configTest, Build, serviceFlag) - os.Exit(1) + if err := doService(configPath, configTest, Build, serviceFlag); err != nil { + l.Error("Service command failed", "error", err) + os.Exit(1) + } + return } if *configPath == "" { @@ -61,9 +66,6 @@ func main() { os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout - c := config.NewC(l) err := c.Load(*configPath) if err != nil { @@ -71,6 +73,16 @@ func main() { os.Exit(1) } + if err := logging.ApplyConfig(l, c); err != nil { + fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { 
util.LogWithContextIfNeeded("Failed to start", err, l) @@ -78,8 +90,20 @@ func main() { } if !*configTest { - ctrl.Start() - ctrl.ShutdownBlock() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.Error("Nebula stopped due to fatal error", "error", err) + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index a54fb0f3..6551ceb4 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -7,9 +7,9 @@ import ( "path/filepath" "github.com/kardianos/service" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" ) var logger service.Logger @@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error { // Start should not block. logger.Info("Nebula service starting.") - l := logrus.New() - HookLogger(l) + l := newPlatformLogger() c := config.NewC(l) err := c.Load(*p.configPath) @@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error { return fmt.Errorf("failed to load config: %s", err) } + if err := logging.ApplyConfig(l, c); err != nil { + return fmt.Errorf("failed to apply logging config: %s", err) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + p.control, err = nebula.Main(c, *p.configTest, Build, l, nil) if err != nil { return err @@ -57,11 +65,11 @@ func fileExists(filename string) bool { return true } -func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { +func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error { if *configPath == "" { ex, err := os.Executable() if err != nil { - panic(err) + return err } *configPath = filepath.Dir(ex) 
+ "/config.yaml" if !fileExists(*configPath) { @@ -85,16 +93,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag * // Here are what the different loggers are doing: // - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr // - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log) - // - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use + // - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger s, err := service.New(prg, svcConfig) if err != nil { - log.Fatal(err) + return err } errs := make(chan error, 5) logger, err = s.Logger(errs) if err != nil { - log.Fatal(err) + return err } go func() { @@ -109,18 +117,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag * switch *serviceFlag { case "run": - err = s.Run() - if err != nil { + if err := s.Run(); err != nil { // Route any errors to the system logger logger.Error(err) } default: - err := service.Control(s, *serviceFlag) - if err != nil { + if err := service.Control(s, *serviceFlag); err != nil { log.Printf("Valid actions: %q\n", service.ControlAction) - log.Fatal(err) + return err } - return } + return nil } diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index ffdc15bf..d7f0de93 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -55,8 +55,7 @@ func main() { os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout + l := logging.NewLogger(os.Stdout) c := config.NewC(l) err := c.Load(*configPath) @@ -65,6 +64,16 @@ func main() { os.Exit(1) 
} + if err := logging.ApplyConfig(l, c); err != nil { + fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) @@ -72,9 +81,21 @@ func main() { } if !*configTest { - ctrl.Start() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() notifyReady(l) - ctrl.ShutdownBlock() + + if err := wait(); err != nil { + l.Error("Nebula stopped due to fatal error", "error", err) + os.Exit(2) + } + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula/notify_linux.go b/cmd/nebula/notify_linux.go index 8c3dca55..965986a9 100644 --- a/cmd/nebula/notify_linux.go +++ b/cmd/nebula/notify_linux.go @@ -1,11 +1,10 @@ package main import ( + "log/slog" "net" "os" "time" - - "github.com/sirupsen/logrus" ) // SdNotifyReady tells systemd the service is ready and dependent services can now be started @@ -13,30 +12,30 @@ import ( // https://www.freedesktop.org/software/systemd/man/systemd.service.html const SdNotifyReady = "READY=1" -func notifyReady(l *logrus.Logger) { +func notifyReady(l *slog.Logger) { sockName := os.Getenv("NOTIFY_SOCKET") if sockName == "" { - l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") + l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal") return } conn, err := net.DialTimeout("unixgram", sockName, time.Second) if err != nil { - l.WithError(err).Error("failed to connect to systemd notification socket") + l.Error("failed to connect to systemd notification socket", "error", err) return } defer conn.Close() err = conn.SetWriteDeadline(time.Now().Add(time.Second)) if err != nil { - 
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket") + l.Error("failed to set the write deadline for the systemd notification socket", "error", err) return } if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { - l.WithError(err).Error("failed to signal the systemd notification socket") + l.Error("failed to signal the systemd notification socket", "error", err) return } - l.Debugln("notified systemd the service is ready") + l.Debug("notified systemd the service is ready") } diff --git a/cmd/nebula/notify_notlinux.go b/cmd/nebula/notify_notlinux.go index e7758e09..48cfe949 100644 --- a/cmd/nebula/notify_notlinux.go +++ b/cmd/nebula/notify_notlinux.go @@ -3,8 +3,8 @@ package main -import "github.com/sirupsen/logrus" +import "log/slog" -func notifyReady(_ *logrus.Logger) { +func notifyReady(_ *slog.Logger) { // No init service to notify } diff --git a/config/config.go b/config/config.go index 0d1be128..5bf994a1 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "math" "os" "os/signal" @@ -16,7 +17,6 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "go.yaml.in/yaml/v3" ) @@ -26,11 +26,11 @@ type C struct { Settings map[string]any oldSettings map[string]any callbacks []func(*C) - l *logrus.Logger + l *slog.Logger reloadLock sync.Mutex } -func NewC(l *logrus.Logger) *C { +func NewC(l *slog.Logger) *C { return &C{ Settings: make(map[string]any), l: l, @@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool { newVals, err := yaml.Marshal(nv) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") + c.l.Error("Error while marshaling new config", + "config_path", k, + "error", err, + ) } oldVals, err := yaml.Marshal(ov) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") + c.l.Error("Error while marshaling old config", + 
"config_path", k, + "error", err, + ) } return string(newVals) != string(oldVals) @@ -154,7 +160,10 @@ func (c *C) ReloadConfig() { err := c.Load(c.path) if err != nil { - c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") + c.l.Error("Error occurred while reloading config", + "config_path", c.path, + "error", err, + ) return } diff --git a/connection_manager.go b/connection_manager.go index 4c2f26ef..e7fc04cd 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -5,13 +5,13 @@ import ( "context" "encoding/binary" "fmt" + "log/slog" "net/netip" "sync" "sync/atomic" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -47,10 +47,10 @@ type connectionManager struct { metricsTxPunchy metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { +func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ hostMap: hm, l: l, @@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.getInactivityTimeout() cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) if !initial { - cm.l.WithField("oldDuration", old). - WithField("newDuration", cm.getInactivityTimeout()). - Info("Inactivity timeout has changed") + cm.l.Info("Inactivity timeout has changed", + "oldDuration", old, + "newDuration", cm.getInactivityTimeout(), + ) } } @@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.dropInactive.Load() cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) if !initial { - cm.l.WithField("oldBool", old). - WithField("newBool", cm.dropInactive.Load()). 
- Info("Drop inactive setting has changed") + cm.l.Info("Drop inactive setting has changed", + "oldBool", old, + "newBool", cm.dropInactive.Load(), + ) } } } @@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo var err error index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { - cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") + cm.l.Error("failed to migrate relay to new hostinfo", "error", err) continue } switch r.Type { @@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo msg, err := req.Marshal() if err != nil { - cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") + cm.l.Error("failed to marshal Control message to migrate relay", "error", err) } else { cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) - cm.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromAddr, - "relayTo": req.RelayToAddr, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddrs": newhostinfo.vpnAddrs}). 
- Info("send CreateRelayRequest") + cm.l.Info("send CreateRelayRequest", + "relayFrom", req.RelayFromAddr, + "relayTo", req.RelayToAddr, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddrs", newhostinfo.vpnAddrs, + ) } } } @@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim hostinfo := cm.hostMap.Indexes[localIndex] if hostinfo == nil { - cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") + cm.l.Debug("Not found in hostmap", "localIndex", localIndex) return doNothing, nil, nil } @@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim // A hostinfo is determined alive if there is incoming traffic if inTraffic { decision := doNothing - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "alive", "method": "passive"}, + ) } hostinfo.pendingDeletion.Store(false) @@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if hostinfo.pendingDeletion.Load() { // We have already sent a test packet and nothing was returned, this hostinfo is dead - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "dead", "method": "active"}). - Info("Tunnel status") + hostinfo.logger(cm.l).Info("Tunnel status", + "tunnelCheck", m{"state": "dead", "method": "active"}, + ) return deleteTunnel, hostinfo, nil } @@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim inactiveFor, isInactive := cm.isInactive(hostinfo, now) if isInactive { // Tunnel is inactive, tear it down - hostinfo.logger(cm.l). - WithField("inactiveDuration", inactiveFor). - WithField("primary", mainHostInfo). 
- Info("Dropping tunnel due to inactivity") + hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity", + "inactiveDuration", inactiveFor, + "primary", mainHostInfo, + ) return closeTunnel, hostinfo, primary } @@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim cm.sendPunch(hostinfo) } - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "testing", "method": "active"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "testing", "method": "active"}, + ) } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues decision = sendTestPacket } else { - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l).Debugf("Hostinfo sadness") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Hostinfo sadness") } } @@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI return false //cert is still valid! yay! } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed // Block listed certificates should always be disconnected - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). - Info("Remote certificate is blocked, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else if cm.intf.disconnectInvalid.Load() { - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). 
- Info("Remote certificate is no longer valid, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else { //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open @@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { curCrtVersion := curCrt.Version() myCrt := cs.getCertificate(curCrtVersion) if myCrt == nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("reason", "local certificate removed"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "reason", "local certificate removed", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } @@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { // if our certificate version is less than theirs, and we have a matching version available, rehandshake? if cs.getCertificate(peerCrt.Certificate.Version()) != nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("peerVersion", peerCrt.Certificate.Version()). - WithField("reason", "local certificate version lower than peer, attempting to correct"). 
- Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "peerVersion", peerCrt.Certificate.Version(), + "reason", "local certificate version lower than peer, attempting to correct", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { hh.initiatingVersionOverride = peerCrt.Certificate.Version() }) @@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { } } if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "local certificate is not current"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "local certificate is not current", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } if curCrtVersion < cs.initiatingVersion { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "current cert version < pki.initiatingVersion"). 
- Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "current cert version < pki.initiatingVersion", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return diff --git a/connection_manager_test.go b/connection_manager_test.go index 647dd72b..a015fba9 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/overlaytest" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := 
NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce @@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.disconnectInvalid.Store(true) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc diff --git a/connection_state.go b/connection_state.go index db885d42..b85aebd4 100644 --- a/connection_state.go +++ b/connection_state.go @@ -8,7 +8,6 @@ import ( "sync/atomic" "github.com/flynn/noise" - 
"github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/noiseutil" ) @@ -27,7 +26,7 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { +func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc switch crt.Curve() { case cert.Curve_CURVE25519: diff --git a/control.go b/control.go index f8567b50..ef58988b 100644 --- a/control.go +++ b/control.go @@ -2,17 +2,33 @@ package nebula import ( "context" + "errors" + "log/slog" "net/netip" "os" "os/signal" + "sync" "syscall" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" ) +type RunState int + +const ( + StateUnknown RunState = iota + StateReady + StateStarted + StateStopping + StateStopped +) + +var ErrAlreadyStarted = errors.New("nebula is already started") +var ErrAlreadyStopped = errors.New("nebula cannot be restarted") +var ErrUnknownState = errors.New("nebula state is invalid") + // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc @@ -26,8 +42,11 @@ type controlHostLister interface { } type Control struct { + stateLock sync.Mutex + state RunState + f *Interface - l *logrus.Logger + l *slog.Logger ctx context.Context cancel context.CancelFunc sshStart func() @@ -49,10 +68,31 @@ type ControlHostInfo struct { CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } -// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() -func (c *Control) Start() { +// Start actually runs nebula, this is a nonblocking call. 
+// The returned function blocks until nebula has fully stopped and returns the +// first fatal reader error (if any). A nil error means nebula shut down +// gracefully; a non-nil error means a reader hit an unexpected failure that +// triggered the shutdown. +func (c *Control) Start() (func() error, error) { + c.stateLock.Lock() + defer c.stateLock.Unlock() + switch c.state { + case StateReady: + //yay! + case StateStopped, StateStopping: + return nil, ErrAlreadyStopped + case StateStarted: + return nil, ErrAlreadyStarted + default: + return nil, ErrUnknownState + } + // Activate the interface - c.f.activate() + err := c.f.activate() + if err != nil { + c.state = StateStopped + return nil, err + } // Call all the delayed funcs that waited patiently for the interface to be created. if c.sshStart != nil { @@ -71,25 +111,51 @@ func (c *Control) Start() { c.lighthouseStart() } + c.f.triggerShutdown = c.Stop + // Start reading packets. - c.f.run() + out, err := c.f.run() + if err != nil { + c.state = StateStopped + return nil, err + } + c.state = StateStarted + return out, nil +} + +func (c *Control) State() RunState { + c.stateLock.Lock() + defer c.stateLock.Unlock() + return c.state } func (c *Control) Context() context.Context { return c.ctx } -// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete +// Stop is a non-blocking call that signals nebula to close all tunnels and shut down func (c *Control) Stop() { + c.stateLock.Lock() + if c.state != StateStarted { + c.stateLock.Unlock() + // We are stopping or stopped already + return + } + + c.state = StateStopping + c.stateLock.Unlock() + // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. 
c.cancel() c.CloseAllTunnels(false) if err := c.f.Close(); err != nil { - c.l.WithError(err).Error("Close interface failed") + c.l.Error("Close interface failed", "error", err) } - c.l.Info("Goodbye") + c.stateLock.Lock() + c.state = StateStopped + c.stateLock.Unlock() } // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled @@ -100,7 +166,7 @@ func (c *Control) ShutdownBlock() { rawSig := <-sigChan sig := rawSig.String() - c.l.WithField("signal", sig).Info("Caught signal, shutting down") + c.l.Info("Caught signal, shutting down", "signal", sig) c.Stop() } @@ -237,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). - Debug("Sending close tunnel message") + c.l.Debug("Sending close tunnel message", + "vpnAddrs", h.vpnAddrs, + "udpAddr", h.remote, + ) closed++ } diff --git a/control_test.go b/control_test.go index e8a5d312..5e381c46 100644 --- a/control_test.go +++ b/control_test.go @@ -6,7 +6,6 @@ import ( "reflect" "testing" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -79,10 +78,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, &Interface{}) c := Control{ + state: StateReady, f: &Interface{ hostMap: hm, }, - l: logrus.New(), + l: test.NewLogger(), } thi := c.GetHostInfoByVpnAddr(vpnIp, false) diff --git a/control_tester.go b/control_tester.go index 7403a745..f927140b 100644 --- a/control_tester.go +++ b/control_tester.go @@ -1,5 +1,4 @@ //go:build e2e_testing -// +build e2e_testing package nebula diff --git a/dist/wireshark/nebula.lua b/dist/wireshark/nebula.lua index ddc808f9..d17dc7a0 100644 --- a/dist/wireshark/nebula.lua +++ b/dist/wireshark/nebula.lua @@ -84,30 +84,24 @@ end 
function nebula.prefs_changed() if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then - -- Nothing changed, bail return end - -- Remove our old dissector + -- Remove all existing registrations DissectorTable.get("udp.port"):remove_all(nebula) - if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then - default_settings.all_port = nebula.prefs.all_ports - + if nebula.prefs.all_ports then + -- Register on every port for hole punch capture for i=0, 65535 do DissectorTable.get("udp.port"):add(i, nebula) end - - -- no need to establish again on specific ports - return + else + -- Register on the configured port only + DissectorTable.get("udp.port"):add(nebula.prefs.port, nebula) end - - if default_settings.all_ports ~= nebula.prefs.all_ports then - -- Add our new port dissector - default_settings.port = nebula.prefs.port - DissectorTable.get("udp.port"):add(default_settings.port, nebula) - end + default_settings.all_ports = nebula.prefs.all_ports + default_settings.port = nebula.prefs.port end DissectorTable.get("udp.port"):add(default_settings.port, nebula) diff --git a/dns_server.go b/dns_server.go index 73576546..ff1369ab 100644 --- a/dns_server.go +++ b/dns_server.go @@ -1,63 +1,249 @@ package nebula import ( + "context" "fmt" + "log/slog" "net" "net/netip" "strconv" "strings" "sync" + "sync/atomic" "github.com/gaissmai/bart" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) -// This whole thing should be rewritten to use context - -var dnsR *dnsRecords -var dnsServer *dns.Server -var dnsAddr string - -type dnsRecords struct { +type dnsServer struct { sync.RWMutex - l *logrus.Logger + l *slog.Logger + ctx context.Context dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr hostMap *HostMap myVpnAddrsTable *bart.Lite + + mux *dns.ServeMux + + // enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`. 
+ // Start, Add, and reload consult it so callers don't need to know the + // gating rules. When it toggles off via reload, accumulated records are + // cleared so a later re-enable starts with a fresh map populated from + // new handshakes. + enabled atomic.Bool + + serverMu sync.Mutex + server *dns.Server + // started is closed once `server` has finished binding (or after + // ListenAndServe returns on a bind failure). Stop waits on it before + // calling Shutdown to avoid the miekg/dns "server not started" race + // where a Shutdown that arrives before bind completes is silently + // ignored, leaving the listener running forever. + started chan struct{} + addr string } -func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { - return &dnsRecords{ +// newDnsServerFromConfig builds a dnsServer, applies the initial config, and +// registers a reload callback. The reload callback is registered before the +// initial config is applied, so a SIGHUP can later enable, fix, or disable +// DNS even if the initial application failed. +// +// The dnsServer internally gates on `lighthouse.serve_dns && +// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally, +// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel +// watcher that tears the listener down on nebula shutdown. The returned +// pointer is always non-nil, even on error. 
+func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { + ds := &dnsServer{ l: l, + ctx: ctx, dnsMap4: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr), hostMap: hostMap, myVpnAddrsTable: cs.myVpnAddrsTable, } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", ds.handleDnsRequest) + + c.RegisterReloadCallback(func(c *config.C) { + if err := ds.reload(c, false); err != nil { + ds.l.Error("Failed to reload DNS responder from config", "error", err) + } + }) + + if err := ds.reload(c, true); err != nil { + return ds, err + } + return ds, nil } -func (d *dnsRecords) Query(q uint16, data string) netip.Addr { +// reload applies the latest config and reconciles the running state with it: +// - enabled toggled on -> spawn a runner +// - enabled toggled off -> stop the runner +// - listen address changed (while running) -> restart on the new address +// - everything else -> no-op +// +// On the initial call it only records configuration; Control.Start is what +// launches the first runner via dnsStart. +func (d *dnsServer) reload(c *config.C, initial bool) error { + wantsDns := c.GetBool("lighthouse.serve_dns", false) + amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) + enabled := wantsDns && amLighthouse + newAddr := getDnsServerAddr(c) + + d.serverMu.Lock() + running := d.server + runningStarted := d.started + sameAddr := d.addr == newAddr + d.addr = newAddr + d.enabled.Store(enabled) + d.serverMu.Unlock() + + if initial { + if wantsDns && !amLighthouse { + d.l.Warn("DNS server refusing to run because this host is not a lighthouse.") + } + return nil + } + + if !enabled { + if running != nil { + d.Stop() + } + // Drop any records that accumulated while enabled; a later re-enable + // will repopulate from fresh handshakes. + d.clearRecords() + return nil + } + + if running == nil { + // Was disabled (or never started); bring it up now. 
+ go d.Start() + return nil + } + + if sameAddr { + return nil + } + + d.shutdownServer(running, runningStarted, "reload") + // Old Start goroutine has now exited; bring up a fresh listener on the + // new address. + go d.Start() + return nil +} + +// shutdownServer waits for the server to finish binding (so Shutdown actually +// stops it rather than no-oping) and then shuts it down. +func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) { + if srv == nil { + return + } + if started != nil { + <-started + } + if err := srv.Shutdown(); err != nil { + d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err) + } +} + +// Start binds and serves the DNS responder. Blocks until Stop is called or +// the listener errors. Safe to call when DNS is disabled (returns +// immediately). This is what Control.dnsStart points at. +// +// Must be invoked after the tun device is active so that lighthouse.dns.host +// may bind to a nebula IP. +func (d *dnsServer) Start() { + if !d.enabled.Load() { + return + } + + started := make(chan struct{}) + d.serverMu.Lock() + if d.ctx.Err() != nil { + d.serverMu.Unlock() + return + } + addr := d.addr + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: d.mux, + NotifyStartedFunc: func() { close(started) }, + } + d.server = server + d.started = started + d.serverMu.Unlock() + + // Per-invocation ctx watcher. Exits when Start does, so we don't leak a + // watcher per reload-driven restart. + done := make(chan struct{}) + go func() { + select { + case <-d.ctx.Done(): + d.shutdownServer(server, started, "shutdown") + case <-done: + } + }() + + d.l.Info("Starting DNS responder", "dnsListener", addr) + err := server.ListenAndServe() + close(done) + + // If the listener never bound (bind error) NotifyStartedFunc never fires, + // so close started here to release any Stop caller waiting on it. 
+ select { + case <-started: + default: + close(started) + } + + if err != nil { + d.l.Warn("Failed to run the DNS responder", "error", err) + } +} + +// Stop shuts down the active server, if any. Idempotent. +func (d *dnsServer) Stop() { + d.serverMu.Lock() + srv := d.server + started := d.started + d.server = nil + d.started = nil + d.serverMu.Unlock() + d.shutdownServer(srv, started, "stop") +} + +// Query returns the address for the given name and query type. The second +// return value reports whether the name is known at all (in either A or AAAA), +// which lets callers distinguish NODATA from NXDOMAIN. +func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) { data = strings.ToLower(data) d.RLock() defer d.RUnlock() + addr4, haveV4 := d.dnsMap4[data] + addr6, haveV6 := d.dnsMap6[data] + nameExists := haveV4 || haveV6 switch q { case dns.TypeA: - if r, ok := d.dnsMap4[data]; ok { - return r + if haveV4 { + return addr4, nameExists } case dns.TypeAAAA: - if r, ok := d.dnsMap6[data]; ok { - return r + if haveV6 { + return addr6, nameExists } } - return netip.Addr{} + return netip.Addr{}, nameExists } -func (d *dnsRecords) QueryCert(data string) string { +func (d *dnsServer) QueryCert(data string) string { + if len(data) < 2 { + return "" + } ip, err := netip.ParseAddr(data[:len(data)-1]) if err != nil { return "" @@ -80,8 +266,19 @@ func (d *dnsRecords) QueryCert(data string) string { return string(b) } +// clearRecords drops all DNS records. 
+func (d *dnsServer) clearRecords() { + d.Lock() + defer d.Unlock() + clear(d.dnsMap4) + clear(d.dnsMap6) +} + // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` -func (d *dnsRecords) Add(host string, addresses []netip.Addr) { +func (d *dnsServer) Add(host string, addresses []netip.Addr) { + if !d.enabled.Load() { + return + } host = strings.ToLower(host) d.Lock() defer d.Unlock() @@ -101,7 +298,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) { } } -func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { +func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { a, _, _ := net.SplitHostPort(addr) b, err := netip.ParseAddr(a) if err != nil { @@ -116,13 +313,24 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { return d.myVpnAddrsTable.Contains(b) } -func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { + debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug) + // Per RFC 2308 §2.2, a name that exists but has no record of the requested + // type must be answered with NOERROR and an empty answer section (NODATA), + // not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not + // exist at all. 
+ anyNameExists := false for _, q := range m.Question { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: qType := dns.TypeToString[q.Qtype] - d.l.Debugf("Query for %s %s", qType, q.Name) - ip := d.Query(q.Qtype, q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", qType, "name", q.Name) + } + ip, nameExists := d.Query(q.Qtype, q.Name) + if nameExists { + anyNameExists = true + } if ip.IsValid() { rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { @@ -134,7 +342,9 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - d.l.Debugf("Query for TXT %s", q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", "TXT", "name", q.Name) + } ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) @@ -145,12 +355,12 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { } } - if len(m.Answer) == 0 { + if len(m.Answer) == 0 && !anyNameExists { m.Rcode = dns.RcodeNameError } } -func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false @@ -163,21 +373,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(l, cs, hostMap) - - // attach request handler func - dns.HandleFunc(".", dnsR.handleDnsRequest) - - c.RegisterReloadCallback(func(c *config.C) { - reloadDns(l, c) - }) - - return func() { - startDns(l, c) - } -} - func getDnsServerAddr(c *config.C) string { dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve. 
@@ -186,25 +381,3 @@ func getDnsServerAddr(c *config.C) string { } return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } - -func startDns(l *logrus.Logger, c *config.C) { - dnsAddr = getDnsServerAddr(c) - dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} - l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") - err := dnsServer.ListenAndServe() - defer dnsServer.Shutdown() - if err != nil { - l.Errorf("Failed to start server: %s\n ", err.Error()) - } -} - -func reloadDns(l *logrus.Logger, c *config.C) { - if dnsAddr == getDnsServerAddr(c) { - l.Debug("No DNS server config change detected") - return - } - - l.Debug("Restarting DNS server") - dnsServer.Shutdown() - go startDns(l, c) -} diff --git a/dns_server_test.go b/dns_server_test.go index 356e5890..dcea046c 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,19 +1,43 @@ package nebula import ( + "context" + "log/slog" + "net" "net/netip" + "strconv" "testing" + "time" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +type stubDNSWriter struct{} + +func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (stubDNSWriter) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353} +} +func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil } +func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil } +func (stubDNSWriter) Close() error { return nil } +func (stubDNSWriter) TsigStatus() error { return nil } +func (stubDNSWriter) TsigTimersOnly(bool) {} +func (stubDNSWriter) Hijack() {} + func TestParsequery(t *testing.T) { - l := logrus.New() + l := slog.New(slog.DiscardHandler) hostMap := &HostMap{} - ds := newDnsRecords(l, &CertState{}, hostMap) + ds := &dnsServer{ + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + } + 
ds.enabled.Store(true) addrs := []netip.Addr{ netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5"), @@ -21,18 +45,56 @@ func TestParsequery(t *testing.T) { netip.MustParseAddr("fd01::25"), } ds.Add("test.com.com", addrs) + ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")}) + ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")}) m := &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) m = &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeAAAA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + // A known name with no record of the requested type should return NODATA + // (NOERROR with empty answer), not NXDOMAIN. + m = &dns.Msg{} + m.SetQuestion("v4only.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + m = &dns.Msg{} + m.SetQuestion("v6only.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeSuccess, m.Rcode) + + // An unknown name should still return NXDOMAIN. 
+ m = &dns.Msg{} + m.SetQuestion("unknown.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) + + // short lookups should not fail + m = &dns.Msg{} + m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}} + ds.parseQuery(m, stubDNSWriter{}) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) + + m = &dns.Msg{} + m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}} + ds.parseQuery(m, stubDNSWriter{}) + assert.Empty(t, m.Answer) + assert.Equal(t, dns.RcodeNameError, m.Rcode) } func Test_getDnsServerAddr(t *testing.T) { @@ -71,3 +133,208 @@ func Test_getDnsServerAddr(t *testing.T) { } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) } + +func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { + t.Helper() + sl := slog.New(slog.DiscardHandler) + ds := &dnsServer{ + l: sl, + ctx: context.Background(), + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: &HostMap{}, + } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", ds.handleDnsRequest) + return ds, config.NewC(nil) +} + +func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) { + c.Settings["lighthouse"] = map[string]any{ + "am_lighthouse": amLighthouse, + "serve_dns": serveDns, + "dns": map[string]any{ + "host": host, + "port": port, + }, + } +} + +func TestDnsServer_reload_initial_disabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, false) + + require.NoError(t, ds.reload(c, true)) + assert.False(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + assert.Nil(t, ds.server) +} + +func TestDnsServer_reload_initial_enabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, true) + + require.NoError(t, ds.reload(c, true)) + assert.True(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + // 
initial never starts a runner; that's Control.Start's job
+	assert.Nil(t, ds.server)
+}
+
+func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
+	ds, c := newTestDnsServer(t)
+	setDnsConfig(c, "127.0.0.1", "0", false, true)
+
+	require.NoError(t, ds.reload(c, true))
+	// Wants DNS but isn't a lighthouse: gated off, no runner.
+	assert.False(t, ds.enabled.Load())
+}
+
+func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
+	ds, c := newTestDnsServer(t)
+	setDnsConfig(c, "127.0.0.1", "0", true, true)
+
+	require.NoError(t, ds.reload(c, true))
+	// No server running yet, no addr change. Reload should not spawn anything.
+	require.NoError(t, ds.reload(c, false))
+	assert.True(t, ds.enabled.Load())
+	assert.Nil(t, ds.server)
+}
+
+func TestDnsServer_StartStop_lifecycle(t *testing.T) {
+	// Bind to a real (random) UDP port so we exercise the actual
+	// ListenAndServe + Shutdown plumbing including the started-chan race fix.
+	port := freeUDPPort(t)
+
+	ds, c := newTestDnsServer(t)
+	setDnsConfig(c, "127.0.0.1", port, true, true)
+	require.NoError(t, ds.reload(c, true))
+
+	done := make(chan struct{})
+	go func() {
+		ds.Start()
+		close(done)
+	}()
+
+	waitFor(t, func() bool {
+		ds.serverMu.Lock()
+		started := ds.started
+		ds.serverMu.Unlock()
+		if started == nil {
+			return false
+		}
+		select {
+		case <-started:
+			return true
+		default:
+			return false
+		}
+	})
+
+	ds.Stop()
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Start did not return after Stop")
+	}
+}
+
+func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
+	// Stop called immediately after Start should not deadlock even if bind
+	// hasn't completed yet. This exercises the started-chan close-on-bind-fail
+	// path: by binding to an address that cannot resolve we get a fast
+	// bind error before NotifyStartedFunc fires.
+ ds, c := newTestDnsServer(t) + // Use a port that should fail to bind (negative would be invalid, use a + // host that won't resolve to ensure listenUDP fails quickly). + setDnsConfig(c, "256.256.256.256", "53", true, true) + require.NoError(t, ds.reload(c, true)) + + done := make(chan struct{}) + go func() { + ds.Start() + close(done) + }() + + // Give Start a moment to attempt the bind and fail. + select { + case <-done: + // Bind failed and Start returned; Stop should be a no-op. + case <-time.After(time.Second): + t.Fatal("Start did not return after a bad bind") + } + + stopped := make(chan struct{}) + go func() { + ds.Stop() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("Stop hung after a failed bind") + } +} + +func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { + port := freeUDPPort(t) + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", port, true, true) + require.NoError(t, ds.reload(c, true)) + + startReturned := make(chan struct{}) + go func() { + ds.Start() + close(startReturned) + }() + waitForBind(t, ds) + + // Toggle serve_dns off; reload should shut the running server down. 
+ setDnsConfig(c, "127.0.0.1", port, true, false) + require.NoError(t, ds.reload(c, false)) + select { + case <-startReturned: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after reload disabled DNS") + } + assert.False(t, ds.enabled.Load()) +} + +func freeUDPPort(t *testing.T) string { + t.Helper() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + port := conn.LocalAddr().(*net.UDPAddr).Port + require.NoError(t, conn.Close()) + return strconv.Itoa(port) +} + +func waitForBind(t *testing.T, ds *dnsServer) { + t.Helper() + waitFor(t, func() bool { + ds.serverMu.Lock() + started := ds.started + ds.serverMu.Unlock() + if started == nil { + return false + } + select { + case <-started: + return true + default: + return false + } + }) +} + +func waitFor(t *testing.T, cond func() bool) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatal("timed out waiting for condition") +} diff --git a/e2e/handshake_manager_test.go b/e2e/handshake_manager_test.go new file mode 100644 index 00000000..3fe784c1 --- /dev/null +++ b/e2e/handshake_manager_test.go @@ -0,0 +1,565 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" +) + +// makeHandshakePacket creates a handshake packet with the given parameters. 
+func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, remoteIndex uint32, counter uint64) *udp.Packet { + data := make([]byte, 200) + header.Encode(data, header.Version, header.Handshake, subtype, remoteIndex, counter) + for i := header.Len; i < len(data); i++ { + data[i] = byte(i) + } + return &udp.Packet{To: to, From: from, Data: data} +} + +func TestHandshakeRetransmitDuplicate(t *testing.T) { + // Verify the responder correctly handles receiving the same msg1 multiple times + // (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen + // and the cached response is resent. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake from me to them") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + + t.Log("Grab my msg1") + msg1 := myControl.GetFromUDP(true) + + t.Log("Inject msg1 into them, first time") + theirControl.InjectUDPPacket(msg1) + _ = theirControl.GetFromUDP(true) + + t.Log("Inject the SAME msg1 again, tests ErrAlreadySeen path") + theirControl.InjectUDPPacket(msg1) + resp2 := theirControl.GetFromUDP(true) + assert.NotNil(t, resp2, "should get cached response on duplicate msg1") + + t.Log("Complete handshake with cached response") + myControl.InjectUDPPacket(resp2) + myControl.WaitForType(1, 0, theirControl) + + t.Log("Drain cached packet and 
verify tunnel works") + cachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify only one tunnel exists on each side") + assert.Len(t, myControl.ListHostmapHosts(false), 1) + assert.Len(t, theirControl.ListHostmapHosts(false), 1) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeTruncatedPacketRecovery(t *testing.T) { + // Verify that a truncated handshake packet is ignored and the real + // packet can still complete the handshake. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + + t.Log("Get msg1 and deliver to responder") + msg1 := myControl.GetFromUDP(true) + theirControl.InjectUDPPacket(msg1) + + t.Log("Get the real response") + realResp := theirControl.GetFromUDP(true) + + t.Log("Truncate the response and inject, should be ignored") + truncResp := realResp.Copy() + truncResp.Data = truncResp.Data[:header.Len] + myControl.InjectUDPPacket(truncResp) + + t.Log("Verify pending handshake survived the truncated packet") + assert.NotEmpty(t, myControl.ListHostmapHosts(true), "pending handshake should still exist") + + t.Log("Inject real 
response, should complete handshake") + myControl.InjectUDPPacket(realResp) + myControl.WaitForType(1, 0, theirControl) + + t.Log("Drain and verify tunnel") + cachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { + // A msg2 arriving with no matching pending index should be silently dropped + // with no response sent and no state changes. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Complete a normal handshake") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + r.RouteForAllUntilTxTun(theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Record hostmap state") + myIndexes := len(myControl.ListHostmapIndexes(false)) + + t.Log("Inject a fake msg2 with unknown RemoteIndex") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0xDEADBEEF, 2)) + + t.Log("Verify no new indexes created") + assert.Equal(t, myIndexes, len(myControl.ListHostmapIndexes(false))) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * 
time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false), "should not send a response to orphaned msg2") + + t.Log("Verify existing tunnel still works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeUnknownMessageCounter(t *testing.T) { + // A handshake packet with an unexpected message counter should be silently + // dropped with no side effects and no UDP response. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + myControl.Start() + theirControl.Start() + + t.Log("Inject handshake with MessageCounter=3") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 3)) + + t.Log("Inject handshake with MessageCounter=99") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 99)) + + t.Log("Verify no tunnels or pending handshakes") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeUnknownSubtype(t *testing.T) { + // A handshake packet with an unknown subtype should be silently dropped. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.Start() + theirControl.Start() + + t.Log("Inject handshake with unknown subtype 99") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.MessageSubType(99), 0, 1)) + + t.Log("Verify no tunnels or pending handshakes") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Verify no UDP response was sent") + time.Sleep(100 * time.Millisecond) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeLateResponse(t *testing.T) { + // After a handshake times out, a late response should be silently ignored + // with no new tunnels created. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{ + "handshakes": m{ + "try_interval": "200ms", + "retries": 2, + }, + }) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + myControl.Start() + theirControl.Start() + + t.Log("Trigger handshake from me") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + + t.Log("Grab msg1 but don't deliver") + msg1 := myControl.GetFromUDP(true) + + t.Log("Wait for handshake to time out") + for i := 0; i < 5; i++ { + time.Sleep(300 * time.Millisecond) + myControl.GetFromUDP(false) + } + + t.Log("Confirm no pending handshakes remain") + assert.Empty(t, myControl.ListHostmapHosts(true)) + + t.Log("Deliver old msg1 to them, they create a tunnel") + theirControl.InjectUDPPacket(msg1) + resp := theirControl.GetFromUDP(true) + assert.NotNil(t, resp) + + t.Log("Inject late response into me, should be ignored") + myControl.InjectUDPPacket(resp) + + t.Log("No tunnel should exist on my side") + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeSelfConnectionRejected(t *testing.T) { + // Verify that a node rejects a handshake containing its own VPN IP in the + // peer cert. We do this by sending the initiator's own msg1 back to itself. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + + // Need a lighthouse entry to trigger a handshake + myControl.InjectLightHouseAddr(netip.MustParseAddr("10.128.0.2"), netip.MustParseAddrPort("10.0.0.2:4242")) + + myControl.Start() + + t.Log("Trigger handshake from me") + myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + msg1 := myControl.GetFromUDP(true) + + t.Log("Drain any handshake retransmits before injecting") + time.Sleep(100 * time.Millisecond) + for myControl.GetFromUDP(false) != nil { + } + + t.Log("Feed my own msg1 back to me as if it came from someone else") + selfMsg := msg1.Copy() + selfMsg.From = netip.MustParseAddrPort("10.0.0.99:4242") + selfMsg.To = myUdpAddr + myControl.InjectUDPPacket(selfMsg) + + t.Log("Verify no response was sent (self-connection rejected)") + time.Sleep(100 * time.Millisecond) + // Drain any further retransmits from the original handshake, then check + // that none of them are a handshake response (MessageCounter=2) + h := &header.H{} + for { + p := myControl.GetFromUDP(false) + if p == nil { + break + } + _ = h.Parse(p.Data) + assert.NotEqual(t, uint64(2), h.MessageCounter, + "should not send a stage 2 response to self-connection") + } + + t.Log("Verify no tunnel to myself was created") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)) + + myControl.Stop() +} + +func TestHandshakeMessageCounter0Dropped(t *testing.T) { + // MessageCounter=0 is not a valid handshake message and should be dropped. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + _, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.Start() + + t.Log("Inject handshake with MessageCounter=0") + myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 0)) + + time.Sleep(100 * time.Millisecond) + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + assert.Nil(t, myControl.GetFromUDP(false)) + + myControl.Stop() +} + +func TestHandshakeRemoteAllowList(t *testing.T) { + // Verify that a handshake from a blocked underlay IP is dropped with no + // response and no state changes. Then verify the same packet from an + // allowed IP succeeds. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{ + "lighthouse": m{ + "remote_allow_list": m{ + "10.0.0.0/8": true, + "0.0.0.0/0": false, + }, + }, + }) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Trigger handshake from them") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")) + msg1 := theirControl.GetFromUDP(true) + + t.Log("Rewrite the source to a blocked IP and inject") + blockedMsg := msg1.Copy() + 
blockedMsg.From = netip.MustParseAddrPort("192.168.1.1:4242") + myControl.InjectUDPPacket(blockedMsg) + + t.Log("Verify no tunnel, no pending, no response from blocked source") + time.Sleep(100 * time.Millisecond) + assert.Empty(t, myControl.ListHostmapHosts(false)) + assert.Empty(t, myControl.ListHostmapHosts(true)) + assert.Nil(t, myControl.GetFromUDP(false), "should not respond to blocked source") + + t.Log("Now inject the real packet from the allowed source") + myControl.InjectUDPPacket(msg1) + + t.Log("Verify handshake completes from allowed source") + resp := myControl.GetFromUDP(true) + assert.NotNil(t, resp) + theirControl.InjectUDPPacket(resp) + theirControl.WaitForType(1, 0, myControl) + + t.Log("Drain cached packet and verify tunnel works") + cachedPacket := myControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi"), cachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { + // When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel + // remains functional and hostmap index count is stable. 
+ + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Complete a normal handshake via the router") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + r.RouteForAllUntilTxTun(theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Record hostmap state") + theirIndexes := len(theirControl.ListHostmapIndexes(false)) + hi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, hi) + originalRemote := hi.CurrentRemote + + t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")) + r.RouteForAllUntilTxTun(theirControl) + + t.Log("Verify tunnel still works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify remote is still valid and index count is stable") + hi2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, hi2) + assert.Equal(t, originalRemote, hi2.CurrentRemote) + assert.Equal(t, theirIndexes, len(theirControl.ListHostmapIndexes(false)), + "no extra indexes should be created from ErrAlreadySeen") + + myControl.Stop() + theirControl.Stop() +} + +func TestHandshakeWrongResponderPacketStore(t *testing.T) { + // Verify that 
when the wrong host responds, the cached packets are + // transferred to the new handshake, the evil tunnel is closed, evil's + // address is blocked, and the correct tunnel is eventually established. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIpNet, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) + + r := router.NewR(t, myControl, theirControl, evilControl) + defer r.RenderFlow() + + myControl.Start() + theirControl.Start() + evilControl.Start() + + t.Log("Send multiple packets to them (cached during handshake)") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")) + + t.Log("Route until evil tunnel is closed") + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + if err := h.Parse(p.Data); err != nil { + panic(err) + } + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + return router.KeepRouting + }) + + t.Log("Verify evil's address is blocked in the new pending handshake") + pendingHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) + if pendingHI != nil { + assert.NotContains(t, pendingHI.RemoteAddrs, evilUdpAddr, + "evil's address should be blocked") + } + + t.Log("Inject correct lighthouse addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + t.Log("Route until cached packets arrive at the 
real them") + p := r.RouteForAllUntilTxTun(theirControl) + assert.NotNil(t, p, "a cached packet should be delivered to the correct host") + + t.Log("Verify the correct host has a tunnel") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + + t.Log("Verify no hostinfo artifacts from evil remain") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), true), + "no pending hostinfo for evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), false), + "no main hostinfo for evil") + + myControl.Stop() + theirControl.Stop() + evilControl.Stop() +} + +func TestHandshakeRelayComplete(t *testing.T) { + // Verify that a relay handshake completes correctly and relay state is + // properly maintained on all three nodes. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger handshake via relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")) + + p := r.RouteForAllUntilTxTun(theirControl) + 
assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Verify bidirectional tunnel via relay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Verify relay state on my side shows relay-to-me") + myHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + assert.NotNil(t, myHI) + assert.NotEmpty(t, myHI.CurrentRelaysToMe, "should have relay-to-me for them") + + t.Log("Verify relay state on their side shows relay-to-me") + theirHI := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, theirHI) + assert.NotEmpty(t, theirHI.CurrentRelaysToMe, "should have relay-to-me for me") + + t.Log("Verify relay node shows through-me relays") + relayHI := relayControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.NotNil(t, relayHI) + + myControl.Stop() + relayControl.Stop() + theirControl.Stop() +} + +// NOTE: Relay V1 cert + IPv6 rejection is not tested here because +// InjectTunUDPPacket from a V4 node to a V6 address panics in the test +// framework. The check is in handshake_manager.go handleOutbound relay +// logic (lines ~304-313): if the relay host has a V1 cert and either +// address is IPv6, the relay is skipped. + +// NOTE: Relay reestablishment (Disestablished state transition) is covered +// by the existing TestReestablishRelays in handshakes_test.go. 
diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 67b166b1..93f200ac 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -11,7 +11,6 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" @@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) { myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) - l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - l.Info("Get a tunnel between me and relay") assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - l.Info("Get a tunnel between them and relay") assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - l.Info("Trigger a handshake from both them and me via relay to them and me") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, 
header.Control, header.MessageNone) - r.Log("Wait for a packet from them to me") - l.Info("Wait for a packet from them to me; myControl") + r.Log("Wait for a packet from them to me; myControl") r.RouteForAllUntilTxTun(myControl) - l.Info("Wait for a packet from them to me; theirControl") + r.Log("Wait for a packet from them to me; theirControl") r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") - l.Info("Wait until we remove extra tunnels") - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) retries := 60 for hostInfos > 6 && retries > 0 { hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... 
myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) @@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) { } r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() @@ -1369,6 +1358,81 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { theirControl.Stop() } +func TestLighthouseUpdateOnReload(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + // Create the lighthouse + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh", "10.128.0.1/24", m{"lighthouse": m{"am_lighthouse": true}}) + + // Create a client with NO lighthouse configured and a long update interval. + // The initial SendUpdate at startup will be a no-op since no lighthouses are known. 
+ myControl, myVpnIpNet, _, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.2/24", m{ + "lighthouse": m{ + "interval": 600, + "local_allow_list": m{ + "10.0.0.0/24": true, + "::/0": false, + }, + }, + }) + + r := router.NewR(t, lhControl, myControl) + defer r.RenderFlow() + + lhControl.Start() + myControl.Start() + + // Drain any startup packets (there should be none meaningful) + r.FlushAll() + + // Verify lighthouse has no knowledge of the client + assert.Nil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr())) + + // Build a new config that adds the lighthouse + newSettings := make(m) + for k, v := range myConfig.Settings { + newSettings[k] = v + } + newSettings["static_host_map"] = m{ + lhVpnIpNet[0].Addr().String(): []any{lhUdpAddr.String()}, + } + newSettings["lighthouse"] = m{ + "hosts": []any{lhVpnIpNet[0].Addr().String()}, + "interval": 600, + "local_allow_list": m{ + "10.0.0.0/24": true, + "::/0": false, + }, + } + newCfg, err := yaml.Marshal(newSettings) + require.NoError(t, err) + + // Reload the config. The lighthouse.hosts change triggers TriggerUpdate, + // which wakes the update worker. It calls SendUpdate, initiating a + // handshake to the new lighthouse and caching the HostUpdateNotification. + require.NoError(t, myConfig.ReloadConfigString(string(newCfg))) + + // Route until the lighthouse receives the HostUpdateNotification. + // This covers: handshake stage 1, stage 2, then the cached update. 
+ done := make(chan struct{}) + go func() { + r.RouteForAllUntilAfterMsgTypeTo(lhControl, header.LightHouse, 0) + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for lighthouse update after config reload") + } + + // Verify lighthouse now has the client's addresses + assert.NotNil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr())) + + r.RenderHostmaps("Final hostmaps", lhControl, myControl) + lhControl.Stop() + myControl.Stop() +} + func TestGoodHandshakeUnsafeDest(t *testing.T) { unsafePrefix := "192.168.6.0/24" ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 39843efe..381ae897 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,7 +4,6 @@ package e2e import ( - "fmt" "io" "net/netip" "os" @@ -12,15 +11,18 @@ import ( "testing" "time" + "log/slog" + "dario.cat/mergo" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.yaml.in/yaml/v3" @@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), - "level": l.Level.String(), + "level": testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), - "level": l.Level.String(), + "level": 
testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr { return a } -func NewTestLogger() *logrus.Logger { - l := logrus.New() - +func NewTestLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - l.SetLevel(logrus.PanicLevel) - return l + return slog.New(slog.NewTextHandler(io.Discard, nil)) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace } - - return l + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// testLogLevelName returns the level name string accepted by logging.ApplyConfig +// for the current TEST_LOGS setting. Kept in sync with NewTestLogger. +func testLogLevelName() string { + switch os.Getenv("TEST_LOGS") { + case "2": + return "debug" + case "3": + return "trace" + case "": + return "info" + } + return "info" } diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index e89cf869..e8e41945 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -12,6 +12,8 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v3" ) @@ -365,3 +367,106 @@ func TestCrossStackRelaysWork(t *testing.T) { //theirControl.Stop() //relayControl.Stop() } + +func TestCloseTunnelAuthenticated(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) + theirControl, theirVpnIpNet, 
theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + r.Log("Close the tunnel") + myControl.CloseTunnel(theirVpnIpNet[0].Addr(), false) + r.FlushAll() + + waitStart := time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 && theirIndexes == 0 { + break + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*6 { + t.Fatal("Tunnel should have been declared inactive after 2 seconds and before 6 seconds") + } + + time.Sleep(1 * time.Second) + //r.FlushAll() + } + + r.Logf("Happy path success, tunnels were dropped within %v", time.Since(waitStart)) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + r.Log("Assert another tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + hi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if hi == nil { + t.Fatal("There is no hostinfo for this tunnel") + } + myHi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if myHi == nil { + t.Fatal("There is no hostinfo for my tunnel") + } + r.Log("It does") + + buf := make([]byte, 1024) + hdr := header.H{ + Version: 1, + Type: header.CloseTunnel, + Subtype: 0, 
+ Reserved: 0, + RemoteIndex: hi.RemoteIndex, + MessageCounter: 5, + } + out, err := hdr.Encode(buf) + if err != nil { + t.Fatal(err) + } + + pkt := &udp.Packet{ + To: hi.CurrentRemote, + From: myHi.CurrentRemote, + Data: out, + } + r.InjectUDPPacket(myControl, theirControl, pkt) + r.Log("Injected bogus close tunnel. Let's see!") + waitStart = time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 { + t.Fatal("myIndexes should not be 0") + } + if theirIndexes == 0 { + t.Fatal("theirIndexes should not be 0, they should have rejected this bogus packet") + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*4 { + t.Log("The tunnel would have been gone by now") + break + } + + time.Sleep(1 * time.Second) + r.FlushAll() + } + + myControl.Stop() + theirControl.Stop() +} diff --git a/examples/config.yml b/examples/config.yml index f81baab6..f5752ae4 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -204,6 +204,12 @@ punchy: # Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access. #trusted_cas: #- "ssh public key string" + # sandbox_dir restricts file paths for profiling commands (start-cpu-profile, save-heap-profile, + # save-mutex-profile) to the specified directory. Relative paths will be resolved within this directory, + # and absolute paths outside of it will be rejected. Default is $TMP/nebula-debug. + # The directory is NOT automatically created. + # Overriding this to "" is the same as "/" and will allow overwriting any path on the host. + #sandbox_dir: /var/tmp/nebula-debug # EXPERIMENTAL: relay support for networks that can't establish direct connections. relay: @@ -286,24 +292,21 @@ tun: # Configure logging level logging: - # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. 
- #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some - # scenarios. Debug logging is also CPU intensive and will decrease performance overall. - # Only enable debug logging while actively investigating an issue. + # trace, debug, info, warn, or error. Default is info and is reloadable. + # fatal and panic are accepted for backwards compatibility and map to error. + #NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some + # scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall. + # Only enable debug or trace logging while actively investigating an issue. level: info - # json or text formats currently available. Default is text + # json or text formats currently available. Default is text. format: text - # Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false + # Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false. #disable_timestamp: true - # timestamp format is specified in Go time format, see: - # https://golang.org/pkg/time/#pkg-constants - # default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339) - # default when `format: text`: - # when TTY attached: seconds since beginning of execution - # otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339) - # As an example, to log as RFC3339 with millisecond precision, set to: - #timestamp_format: "2006-01-02T15:04:05.000Z07:00" + # Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable. +# The stats section is reloadable. A HUP may change the backend, toggle stats +# on or off, switch the listen/host address, or pick up new DNS for the +# configured graphite host. 
#stats: #type: graphite #prefix: nebula @@ -321,10 +324,12 @@ logging: # enables counter metrics for meta packets # e.g.: `messages.tx.handshake` # NOTE: `message.{tx,rx}.recv_error` is always emitted + # Not reloadable. #message_metrics: false # enables detailed counter metrics for lighthouse packets # e.g.: `lighthouse.rx.HostQuery` + # Not reloadable. #lighthouse_metrics: false # Handshake Manager Settings @@ -382,8 +387,8 @@ firewall: # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). - # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` + # a port specification is ignored if proto is `icmp` # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. 
Multiple values are AND'd together and a certificate would have to contain all groups to pass diff --git a/examples/go_service/main.go b/examples/go_service/main.go index 2f8efbfb..3f98fe3d 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -7,9 +7,9 @@ import ( "net" "os" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/service" ) @@ -64,8 +64,7 @@ pki: return err } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/firewall.go b/firewall.go index 45dc0691..adecbe81 100644 --- a/firewall.go +++ b/firewall.go @@ -1,11 +1,13 @@ package nebula import ( + "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "hash/fnv" + "log/slog" "net/netip" "reflect" "slices" @@ -16,7 +18,6 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -67,7 +68,7 @@ type Firewall struct { incomingMetrics firewallMetrics outgoingMetrics firewallMetrics - l *logrus.Logger + l *slog.Logger } type firewallMetrics struct { @@ -131,7 +132,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. 
-func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { +func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.InSendReject = false default: - l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`") + l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction) fw.InSendReject = false } @@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.OutSendReject = false default: - l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") + l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction) fw.OutSendReject = false } @@ -249,20 +250,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew // AddRule properly creates the in memory rule structure for a firewall table. func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { - // We need this rule string because we generate a hash. Removing this will break firewall reload. 
- ruleString := fmt.Sprintf( - "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, - ) - f.rules += ruleString + "\n" - - direction := "incoming" - if !incoming { - direction = "outgoing" - } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). - Info("Firewall rule added") - var ( ft *FirewallTable fp firewallPort @@ -280,6 +267,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort case firewall.ProtoUDP: fp = ft.UDP case firewall.ProtoICMP, firewall.ProtoICMPv6: + //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided + if startPort != firewall.PortAny { + f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort) + } + startPort = firewall.PortAny + endPort = firewall.PortAny fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -287,6 +280,21 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } + // We need this rule string because we generate a hash. Removing this will break firewall reload. 
+ ruleString := fmt.Sprintf( + "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", + incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, + ) + f.rules += ruleString + "\n" + + direction := "incoming" + if !incoming { + direction = "outgoing" + } + f.l.Info("Firewall rule added", + "firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}, + ) + return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } @@ -308,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } -func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { table = "firewall.inbound" @@ -349,24 +357,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw sPort = r.Port } - startPort, endPort, err := parsePort(sPort) - if err != nil { - return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) - } - var proto uint8 + var startPort, endPort int32 switch r.Proto { case "any": proto = firewall.ProtoAny + startPort, endPort, err = parsePort(sPort) case "tcp": proto = firewall.ProtoTCP + startPort, endPort, err = parsePort(sPort) case "udp": proto = firewall.ProtoUDP + startPort, endPort, err = parsePort(sPort) case "icmp": proto = firewall.ProtoICMP + startPort = firewall.PortAny + endPort = firewall.PortAny + if sPort != "" { + l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort) + } default: return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } + if err != nil { + 
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) + } if r.Cidr != "" && r.Cidr != "any" { _, err = netip.ParsePrefix(r.Cidr) @@ -383,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw } if warning := r.sanity(); warning != nil { - l.Warnf("%s rule #%v; %s", table, i, warning) + l.Warn("firewall rule sanity check", + "table", table, + "rule", i, + "warning", warning, + ) } err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) @@ -467,7 +486,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics { } } -// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new +// Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new // firewall object is created func (f *Firewall) Destroy() { //TODO: clean references if/when needed @@ -515,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, // We now know which firewall table to check against if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). - WithField("oldRulesVersion", c.rulesVersion). - Debugln("dropping old conntrack entry, does not match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } delete(conntrack.Conns, fp) conntrack.Unlock() return false } - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). 
- WithField("oldRulesVersion", c.rulesVersion). - Debugln("keeping old conntrack entry, does match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } c.rulesVersion = f.rulesVersion @@ -660,6 +679,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer return false } + // this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match + if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 { + // port numbers are re-used for connection tracking of ICMP, + // but we don't want to actually filter on them. + return fp[firewall.PortAny].match(p, c, caPool) + } + var port int32 if p.Fragment { @@ -804,10 +830,8 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { return true } - for _, group := range groups { - if group == "any" { - return true - } + if slices.Contains(groups, "any") { + return true } if host == "any" { @@ -917,7 +941,7 @@ type rule struct { CASha string } -func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { +func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) { r := rule{} m, ok := p.(map[string]any) @@ -948,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } - l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) + l.Warn("group was an array with a single value, converting to simple value", + "table", table, + "rule", i, + ) m["group"] = v[0] } @@ -1018,54 +1045,56 @@ func (r *rule) sanity() error { } } + if r.Code != "" { + return fmt.Errorf("code specified as [%s]. 
Support for 'code' will be dropped in a future release, as it has never been functional", r.Code) + } + //todo alert on cidr-any return nil } -func parsePort(s string) (startPort, endPort int32, err error) { +func parsePort(s string) (int32, int32, error) { + var err error + const notAPort int32 = -2 if s == "any" { - startPort = firewall.PortAny - endPort = firewall.PortAny - - } else if s == "fragment" { - startPort = firewall.PortFragment - endPort = firewall.PortFragment - - } else if strings.Contains(s, `-`) { - sPorts := strings.SplitN(s, `-`, 2) - sPorts[0] = strings.Trim(sPorts[0], " ") - sPorts[1] = strings.Trim(sPorts[1], " ") - - if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { - return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) - } - - rStartPort, err := strconv.Atoi(sPorts[0]) - if err != nil { - return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) - } - - rEndPort, err := strconv.Atoi(sPorts[1]) - if err != nil { - return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) - } - - startPort = int32(rStartPort) - endPort = int32(rEndPort) - - if startPort == firewall.PortAny { - endPort = firewall.PortAny - } - - } else { + return firewall.PortAny, firewall.PortAny, nil + } + if s == "fragment" { + return firewall.PortFragment, firewall.PortFragment, nil + } + if !strings.Contains(s, `-`) { rPort, err := strconv.Atoi(s) if err != nil { - return 0, 0, fmt.Errorf("was not a number; `%s`", s) + return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) } - startPort = int32(rPort) - endPort = startPort + return int32(rPort), int32(rPort), nil } - return + sPorts := strings.SplitN(s, `-`, 2) + for i := range sPorts { + sPorts[i] = strings.Trim(sPorts[i], " ") + } + if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { + return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) + } + + rStartPort, err := 
strconv.Atoi(sPorts[0]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) + } + + rEndPort, err := strconv.Atoi(sPorts[1]) + if err != nil { + return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) + } + + startPort := int32(rStartPort) + endPort := int32(rEndPort) + + if startPort == firewall.PortAny { + endPort = firewall.PortAny + } + + return startPort, endPort, nil } diff --git a/firewall/cache.go b/firewall/cache.go index 71b83f43..ba4b9732 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -1,10 +1,10 @@ package firewall import ( + "context" + "log/slog" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // ConntrackCache is used as a local routine cache to know if a given flow @@ -15,41 +15,49 @@ type ConntrackCacheTicker struct { cacheV uint64 cacheTick atomic.Uint64 + l *slog.Logger cache ConntrackCache } -func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { +func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker { if d == 0 { return nil } c := &ConntrackCacheTicker{ + l: l, cache: ConntrackCache{}, } - go c.tick(d) + go c.tick(ctx, d) return c } -func (c *ConntrackCacheTicker) tick(d time.Duration) { +func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) { + t := time.NewTicker(d) + defer t.Stop() for { - time.Sleep(d) - c.cacheTick.Add(1) + select { + case <-ctx.Done(): + return + case <-t.C: + c.cacheTick.Add(1) + } } } // Get checks if the cache ticker has moved to the next version before returning // the map. If it has moved, we reset the map. 
-func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { +func (c *ConntrackCacheTicker) Get() ConntrackCache { if c == nil { return nil } if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { - if l.Level == logrus.DebugLevel { - l.WithField("len", ll).Debug("resetting conntrack cache") + if c.l.Enabled(context.Background(), slog.LevelDebug) { + c.l.Debug("resetting conntrack cache", "len", ll) } c.cache = make(ConntrackCache, ll) } diff --git a/firewall/cache_test.go b/firewall/cache_test.go new file mode 100644 index 00000000..ab807984 --- /dev/null +++ b/firewall/cache_test.go @@ -0,0 +1,69 @@ +package firewall + +import ( + "bytes" + "log/slog" + "strings" + "testing" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +// The tests below pin the log format produced by ConntrackCacheTicker.Get +// so changes cannot silently break what operators are grepping for. The +// ticker's internal state (cache + cacheTick) is poked directly to avoid +// racing a goroutine-driven tick in tests. 
+ +func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker { + t.Helper() + c := &ConntrackCacheTicker{ + l: l, + cache: make(ConntrackCache, cacheLen), + } + for i := 0; i < cacheLen; i++ { + c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{} + } + c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path + return c +} + +func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 3) + c.Get() + + assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String()) +} + +func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 2) + c.Get() + + assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String())) +} + +func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo) + + c := newFixedTicker(t, l, 5) + c.Get() + + assert.Empty(t, buf.String()) +} + +func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 0) + c.Get() + + assert.Empty(t, buf.String()) +} diff --git a/firewall/packet.go b/firewall/packet.go index 40c7fc5d..2cbfb5ea 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -22,7 +22,10 @@ const ( type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr - LocalPort uint16 + // LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP. + LocalPort uint16 + // RemotePort is the source port for incoming traffic, or the destination port for outgoing. + // For ICMP, it's the "identifier". 
This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier RemotePort uint16 Protocol uint8 Fragment bool @@ -46,6 +49,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = "tcp" case ProtoICMP: proto = "icmp" + case ProtoICMPv6: + proto = "icmpv6" case ProtoUDP: proto = "udp" default: diff --git a/firewall_test.go b/firewall_test.go index 1df62a81..cbf090fd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -3,13 +3,13 @@ package nebula import ( "bytes" "errors" + "log/slog" "math" "net/netip" "testing" "time" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) { } func TestFirewall_AddRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) @@ -87,9 +86,10 @@ func TestFirewall_AddRule(t *testing.T) { fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) - assert.Nil(t, fw.InRules.ICMP[1].Any.Any) - assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) - assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") + //no matter what port is given for icmp, it should end up as "any" + assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) + assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) + assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) @@ -176,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) { } func TestFirewall_Drop(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := 
test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ @@ -253,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) { } func TestFirewall_DropV6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -484,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) { } func TestFirewall_Drop2(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -543,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) { } func TestFirewall_Drop3(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -632,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) { } func TestFirewall_Drop3V6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -670,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) { } func TestFirewall_DropConntrackReload(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -734,10 +728,152 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) } -func TestFirewall_DropIPSpoofing(t *testing.T) { - l := test.NewLogger() +func TestFirewall_ICMPPortBehavior(t *testing.T) { ob := &bytes.Buffer{} - l.SetOutput(ob) + l := 
test.NewLoggerWithOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) + + network := netip.MustParsePrefix("1.2.3.4/24") + + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + }, + InvertedGroups: map[string]struct{}{"default-group": {}}, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(myVpnNetworksTable, c.Certificate) + + cp := cert.NewCAPool() + + templ := firewall.Packet{ + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), + Protocol: firewall.ProtoICMP, + Fragment: false, + } + + t.Run("ICMP allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + + t.Run("nonzero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + }) + + t.Run("Any proto, some ports allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, 
[]string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 80 + p.RemotePort = 80 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + t.Run("Any proto, any port", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, allowed", func(t *testing.T) { + resetConntrack(fw) + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + }) + + t.Run("nonzero ports, allowed", func(t 
*testing.T) { + resetConntrack(fw) + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + //different ID is blocked + p.RemotePort++ + require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + +} + +func TestFirewall_DropIPSpoofing(t *testing.T) { + ob := &bytes.Buffer{} + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) @@ -900,53 +1036,53 @@ func TestNewFirewallFromConfig(t *testing.T) { cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) require.NoError(t, err) - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error - conf = config.NewC(l) - conf.Settings["firewall"] = map[string]any{"outbound": 
[]any{map[string]any{"code": "a", "host": "testh"}}} + conf = config.NewC(test.NewLogger()) + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") - conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", 
"groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") @@ -955,28 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) { l := test.NewLogger() // Test adding tcp rule - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) mf := &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding udp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) + + // Test adding icmp rule no port + conf = config.NewC(test.NewLogger()) + mf = &mockFirewall{} + 
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -984,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr cidr := netip.MustParsePrefix("10.0.0.0/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) // Test adding rule with local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -999,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr ipv6 cidr6 := netip.MustParsePrefix("fd00::/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) // Test adding rule with any cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) // Test adding rule with junk cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with local_cidr ipv6 - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) // Test adding rule with any local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) // Test adding 
rule with junk local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) // Test single group - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test single groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": 
[]any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test multiple AND groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) // Test Add error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} @@ -1082,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { } func TestFirewall_convertRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) // Ensure group array of 1 is converted and a warning is printed c := map[string]any{ @@ -1092,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) { } r, err := convertRule(l, c, "test", 1) - assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "table=test") + assert.Contains(t, ob.String(), "rule=1") require.NoError(t, err) assert.Equal(t, []string{"group1"}, r.Groups) @@ -1118,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) { } func TestFirewall_convertRuleSanity(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - 
l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) noWarningPlease := []map[string]any{ {"group": "group1"}, @@ -1234,7 +1377,7 @@ type testsetup struct { fw *Firewall } -func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { +func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup { c := dummyCert{ name: "me", networks: myPrefixes, @@ -1245,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse return newSetupFromCert(t, l, c) } -func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { +func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup { myVpnNetworksTable := new(bart.Lite) for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) @@ -1262,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { t.Parallel() - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myPrefix := netip.MustParsePrefix("1.1.1.1/8") // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out diff --git a/go.mod b/go.mod index 1c564d03..0de2df7d 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,10 @@ module github.com/slackhq/nebula -go 1.25 +go 1.25.0 require ( dario.cat/mergo v1.0.2 + filippo.io/bigmod v0.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 @@ -12,26 +13,25 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 - github.com/miekg/dns v1.1.70 - github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b + github.com/miekg/dns v1.1.72 + github.com/miekg/pkcs11 v1.1.2 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f 
github.com/prometheus/client_golang v1.23.2 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 - github.com/sirupsen/logrus v1.9.4 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.3.1 go.yaml.in/yaml/v3 v3.0.4 - golang.org/x/crypto v0.47.0 + golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.49.0 - golang.org/x/sync v0.19.0 - golang.org/x/sys v0.40.0 - golang.org/x/term v0.39.0 + golang.org/x/net v0.52.0 + golang.org/x/sync v0.20.0 + golang.org/x/sys v0.43.0 + golang.org/x/term v0.42.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b - golang.zx2c4.com/wireguard/windows v0.5.3 + golang.zx2c4.com/wireguard/windows v0.6.1 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe @@ -49,7 +49,7 @@ require ( github.com/prometheus/procfs v0.16.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/mod v0.31.0 // indirect + golang.org/x/mod v0.34.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.40.0 // indirect + golang.org/x/tools v0.43.0 // indirect ) diff --git a/go.sum b/go.sum index c4613e01..aad164c7 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8= +filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod 
h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -83,10 +85,10 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA= -github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= -github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= -github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= +github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ= +github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -131,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod 
h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= -github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= @@ -162,16 +162,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 
-golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -191,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text 
v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -223,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -233,8 +233,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= -golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= -golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= +golang.zx2c4.com/wireguard/windows v0.6.1 h1:XMaKojH1Hs/raMrmnir4n35nTvzvWj7NmSYzHn2F4qU= +golang.zx2c4.com/wireguard/windows v0.6.1/go.mod 
h1:04aqInu5GYuTFvMuDw/rKBAF7mHrltW/3rekpfbbZDM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= diff --git a/handshake_ix.go b/handshake_ix.go index 4e04f450..a086960e 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,11 +2,12 @@ package nebula import ( "bytes" + "context" + "log/slog" "net/netip" "time" "github.com/flynn/noise" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" ) @@ -18,8 +19,11 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") + f.l.Error("Failed to generate index", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { crt := cs.getCertificate(v) if crt == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Unable to handshake with host because no certificate is available") + f.l.Error("Unable to handshake with host because no certificate is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } crtHs := cs.getHandshakeBytes(v) if crtHs == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). 
- Error("Unable to handshake with host because no certificate handshake bytes is available") + f.l.Error("Unable to handshake with host because no certificate handshake bytes is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } - ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Failed to create connection state") + f.l.Error("Failed to create connection state", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } hh.hostinfo.ConnectionState = ci @@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("certVersion", v). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + f.l.Error("Failed to marshal handshake message", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "certVersion", v, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). 
- WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + f.l.Error("Failed to call noise.WriteMessage", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) cs := f.pki.getCertState() crt := cs.GetDefaultCertificate() if crt == nil { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", cs.initiatingVersion). - Error("Unable to handshake with host because no certificate is available") + f.l.Error("Unable to handshake with host because no certificate is available", + "from", via, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", cs.initiatingVersion, + ) return } - ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to create connection state") + f.l.Error("Failed to create connection state", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to call noise.ReadMessage") + f.l.Error("Failed to call noise.ReadMessage", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Error("Failed unmarshal handshake message") + f.l.Error("Failed unmarshal handshake message", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") + f.l.Info("Handshake did not contain a certificate", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) fp = "" } - e := f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("certVpnNetworks", rc.Networks()). - WithField("certFingerprint", fp) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) + attrs := []slog.Attr{ + slog.Any("error", err), + slog.Any("from", via), + slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}), + slog.Any("certVpnNetworks", rc.Networks()), + slog.String("certFingerprint", fp), + } + if f.l.Enabled(context.Background(), slog.LevelDebug) { + attrs = append(attrs, slog.Any("cert", rc)) } - e.Info("Invalid certificate from host") + // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that + // callers grow conditionally, which has no pair-form equivalent. + //nolint:sloglint + f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) return } if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") + f.l.Info("public key mismatch between certificate and handshake", + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "cert", remoteCert, + ) return } @@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) // We started off using the wrong certificate version, lets see if we can match the version that was sent to us myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) if myCertOtherVersion == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithError(err).WithFields(m{ - "from": via, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - "cert": remoteCert, - }).Debug("Might be unable to handshake with host due to missing certificate version") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Might be unable to handshake with host due to missing certificate version", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "cert", remoteCert, + ) } } else { // Record the certificate we are actually using @@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("No networks in certificate") + f.l.Info("No networks in certificate", + "error", err, + "from", via, + "cert", remoteCert, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(network.Addr()) { - f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). - WithField("certName", certName). 
- WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } vpnAddrs[i] = network.Addr() @@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) if !via.IsRelayed { // We only want to apply the remote allow list for direct tunnels here if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - Debug("lighthouse.remote_allow_list denied incoming handshake") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", vpnAddrs, + "from", via, + ) + } return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") + f.l.Error("Failed to generate index", + "error", err, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) }, } - msgRxL := f.l.WithFields(m{ - "vpnAddrs": vpnAddrs, - "from": via, - "certName": certName, - "certVersion": certVersion, - "fingerprint": fingerprint, - "issuer": issuer, - "initiatorIndex": hs.Details.InitiatorIndex, - "responderIndex": hs.Details.ResponderIndex, - "remoteIndex": h.RemoteIndex, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - }) + msgRxL := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) if anyVpnAddrsInCommon { msgRxL.Info("Handshake message received") @@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) hs.Details.ResponderIndex = myIndex hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) if hs.Details.Cert == nil { - msgRxL.WithField("myCertVersion", ci.myCert.Version()). - Error("Unable to handshake with host because no certificate handshake bytes is available") + msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available", + "myCertVersion", ci.myCert.Version(), + ) return } @@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). 
- WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + f.l.Error("Failed to marshal handshake message", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + f.l.Error("Failed to call noise.WriteMessage", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") + f.l.Error("Noise did not arrive at a key", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) if !via.IsRelayed { err := f.outside.WriteTo(msg, via.UdpAddr) if err != nil { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - WithError(err).Error("Failed to send handshake message") + f.l.Error("Failed to send handshake message", + "vpnAddrs", existing.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + "error", err, + ) } else { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", existing.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + ) } return } else { @@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) } hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). 
- Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", existing.vpnAddrs, + "relay", via.relayHI.vpnAddrs[0], + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + ) return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("oldHandshakeTime", existing.lastHandshakeTime). - WithField("newHandshakeTime", hostinfo.lastHandshakeTime). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake too old") + f.l.Info("Handshake too old", + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "oldHandshakeTime", existing.lastHandshakeTime, + "newHandshakeTime", hostinfo.lastHandshakeTime, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). 
- WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). - Error("Failed to add HostInfo due to localIndex collision") + f.l.Error("Failed to add HostInfo due to localIndex collision", + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "localIndex", hostinfo.localIndexId, + "collision", existing.vpnAddrs, + ) return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Error("Failed to add HostInfo to HostMap") + f.l.Error("Failed to add HostInfo to HostMap", + "error", err, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } } @@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) if !via.IsRelayed { err = f.outside.WriteTo(msg, via.UdpAddr) - log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + log := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) if err != nil { - log.WithError(err).Error("Failed to send handshake") + log.Error("Failed to send handshake", "error", err) } else { log.Info("Handshake message sent") } @@ -448,20 +533,29 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) // it's correctly marked as working. 
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", vpnAddrs, + "relay", via.relayHI.vpnAddrs[0], + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) } f.connectionManager.AddTrafficWatch(hostinfo) hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } + return } @@ -478,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe if !via.IsRelayed { // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. 
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + ) + } return false } } @@ -486,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Error("Failed to call noise.ReadMessage") + f.l.Error("Failed to call noise.ReadMessage", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "header", h, + ) // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") + f.l.Error("Noise did not arrive at a key", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // This should be impossible in IX but just in case, if we get here then there is no chance to recover // the handshake state machine. 
Tear it down @@ -507,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + f.l.Error("Failed unmarshal handshake message", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true @@ -516,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") + f.l.Info("Handshake did not contain a certificate", + "error", err, + "from", via, + "vpnAddrs", hostinfo.vpnAddrs, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) return true } @@ -530,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe fp = "" } - e := f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("certFingerprint", fp). 
- WithField("certVpnNetworks", rc.Networks()) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) + attrs := []slog.Attr{ + slog.Any("error", err), + slog.Any("from", via), + slog.Any("vpnAddrs", hostinfo.vpnAddrs), + slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}), + slog.String("certFingerprint", fp), + slog.Any("certVpnNetworks", rc.Networks()), + } + if f.l.Enabled(context.Background(), slog.LevelDebug) { + attrs = append(attrs, slog.Any("cert", rc)) } - e.Info("Invalid certificate from host") + // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that + // callers grow conditionally, which has no pair-form equivalent. + //nolint:sloglint + f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) return true } if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") + f.l.Info("public key mismatch between certificate and handshake", + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cert", remoteCert, + ) return true } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("No networks in certificate") + f.l.Info("No networks in certificate", + "error", err, + "from", via, + "vpnAddrs", hostinfo.vpnAddrs, + "cert", remoteCert, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) return true } @@ -596,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe // Ensure the right host responded if !correctHostResponded { - f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("from", via). 
- WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Incorrect host responded to handshake") + f.l.Info("Incorrect host responded to handshake", + "intendedVpnAddrs", hostinfo.vpnAddrs, + "haveVpnNetworks", vpnNetworks, + "from", via, + "certName", certName, + "certVersion", certVersion, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // Release our old handshake from pending, it should not continue f.handshakeManager.DeleteHostInfo(hostinfo) @@ -613,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes.BlockRemote(via) - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). - WithField("vpnNetworks", vpnNetworks). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). - Info("Blocked addresses for handshakes") + f.l.Info("Blocked addresses for handshakes", + "blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(), + "vpnNetworks", vpnNetworks, + "remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()), + ) // Swap the packet store to benefit the original intended recipient newHH.packetStore = hh.packetStore @@ -634,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("durationNs", duration). 
- WithField("sentCachedPackets", len(hh.packetStore)) + msgRxL := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "durationNs", duration, + "sentCachedPackets", len(hh.packetStore), + ) if anyVpnAddrsInCommon { msgRxL.Info("Handshake message received") } else { @@ -658,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Sending stored packets", + "count", len(hh.packetStore), + ) } if len(hh.packetStore) > 0 { @@ -674,5 +804,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe hostinfo.remotes.RefreshFromHandshake(vpnAddrs) f.metricHandshakes.Update(duration) + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } + return false } diff --git a/handshake_manager.go b/handshake_manager.go index 8b1ce839..8040ec2e 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,13 +6,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "log/slog" "net/netip" "slices" "sync" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -59,7 +59,7 @@ type HandshakeManager struct { metricInitiated metrics.Counter metricTimedOut metrics.Counter f *Interface - l *logrus.Logger + l *slog.Logger // can be used to trigger outbound handshake for the given 
vpnIp trigger chan netip.Addr @@ -78,32 +78,32 @@ type HandshakeHostInfo struct { hostinfo *HostInfo } -func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { +func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { if len(hh.packetStore) < 100 { tempPacket := make([]byte, len(packet)) copy(tempPacket, packet) hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", true). - Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", true, + ) } } else { m.dropped.Inc(1) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", false). 
- Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", false, + ) } } } -func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, @@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head // First remote allow list check before we know the vpnIp if !via.IsRelayed { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { - hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") + hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via) return } } @@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). - WithField("initiatorIndex", hh.hostinfo.localIndexId). - WithField("remoteIndex", hh.hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). 
- Info("Handshake timed out") + hh.hostinfo.logger(hm.l).Info("Handshake timed out", + "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), + "initiatorIndex", hh.hostinfo.localIndexId, + "remoteIndex", hh.hostinfo.remoteIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "durationNs", time.Since(hh.startTime).Nanoseconds(), + ) hm.metricTimedOut.Inc(1) hm.DeleteHostInfo(hostinfo) return @@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { - hostinfo.logger(hm.l).WithField("udpAddr", addr). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake message") + hostinfo.logger(hm.l).Error("Failed to send handshake message", + "udpAddr", addr, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "error", err, + ) } else { sentTo = append(sentTo, addr) @@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // so only log when the list of remotes has changed if remotesHaveChanged { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message sent") - } else if hm.l.Level >= logrus.DebugLevel { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
- Debug("Handshake message sent") + hostinfo.logger(hm.l).Info("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) + } else if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(hm.l).Debug("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) } if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") + hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay through the host I'm trying to connect to @@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String()) hm.f.Handshake(relay) continue } @@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") + hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) } m := NebulaControl{ @@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered msg, err := m.Marshal() if err != nil { - hostinfo.logger(hm.l). - WithError(err). 
- Error("Failed to marshal Control message to create relay") + hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": idx, - "relay": relay}). - Info("send CreateRelayRequest") + hm.l.Info("send CreateRelayRequest", + "relayFrom", hm.f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", idx, + "relay", relay, + ) } } continue @@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered switch existingRelay.State { case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String()) hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) fallthrough case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String()) // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, @@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } msg, err := m.Marshal() if err != nil { - hostinfo.logger(hm.l). - WithError(err). 
- Error("Failed to marshal Control message to create relay") + hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": relay}). - Info("send CreateRelayRequest") + hm.l.Info("send CreateRelayRequest", + "relayFrom", hm.f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", existingRelay.LocalIndex, + "relay", relay, + ) } case PeerRequested: // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. fallthrough default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relay). - Errorf("Relay unexpected state") + hostinfo.logger(hm.l).Error("Relay unexpected state", + "vpnIp", vpnIp, + "state", existingRelay.State, + "relay", relay, + ) } } @@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). 
- Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) @@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). - Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. @@ -590,7 +593,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { hm.Lock() defer hm.Unlock() - for i := 0; i < 32; i++ { + for range 32 { index, err := generateIndex(hm.l) if err != nil { return err @@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { hm.indexes = map[uint32]*HandshakeHostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
- Debug("Pending hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Pending hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } } @@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() { // Utility functions below -func generateIndex(l *logrus.Logger) (uint32, error) { +func generateIndex(l *slog.Logger) (uint32, error) { b := make([]byte, 4) // Let zero mean we don't know the ID, so don't generate zero @@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) { for index == 0 { _, err := rand.Read(b) if err != nil { - l.Errorln(err) + l.Error("Failed to generate index", "error", err) return 0, err } index = binary.BigEndian.Uint32(b) } - if l.Level >= logrus.DebugLevel { - l.WithField("index", index). - Debug("Generated index") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Generated index", "index", index) } return index, nil } diff --git a/hostmap.go b/hostmap.go index 7e2939e0..08acd1be 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,9 +1,11 @@ package nebula import ( + "context" "encoding/json" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -13,10 +15,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" ) const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address @@ -60,7 +62,7 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - l *logrus.Logger + l *slog.Logger } // For synchronization, treat the pointed-to Relay struct as immutable. 
To edit the Relay @@ -313,7 +315,7 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { +func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap { hm := newHostMap(l) hm.reload(c, true) @@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { hm.reload(c, false) }) - l.WithField("preferredRanges", hm.GetPreferredRanges()). - Info("Main HostMap created") + l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges()) return hm } -func newHostMap(l *logrus.Logger) *HostMap { +func newHostMap(l *slog.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, @@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { - hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") + hm.l.Warn("Failed to parse preferred ranges, ignoring", + "error", err, + "range", rawPreferredRanges, + ) continue } @@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { oldRanges := hm.preferredRanges.Swap(&preferredRanges) if !initial { - hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") + hm.l.Info("preferred_ranges changed", + "oldPreferredRanges", *oldRanges, + "newPreferredRanges", preferredRanges, + ) } } } @@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad hm.Indexes = map[uint32]*HostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
- Debug("Hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.Hosts), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } if isLastHostinfo { @@ -604,9 +612,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. // If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { - if f.serveDns { + if f.dnsServer != nil { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) + f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } for _, addr := range hostinfo.vpnAddrs { hm.unlockedInnerAddHostInfo(addr, hostinfo, f) @@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). 
- Debug("Hostmap vpnIp added") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap vpnIp added", + "hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}, + ) } } @@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica } } -func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { +// logger returns a derived slog.Logger with per-hostinfo fields pre-bound. +func (i *HostInfo) logger(l *slog.Logger) *slog.Logger { if i == nil { - return logrus.NewEntry(l) + return l } - li := l.WithField("vpnAddrs", i.vpnAddrs). - WithField("localIndex", i.localIndexId). - WithField("remoteIndex", i.remoteIndexId) + li := l.With( + "vpnAddrs", i.vpnAddrs, + "localIndex", i.localIndexId, + "remoteIndex", i.remoteIndexId, + ) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Certificate.Name()) + li = li.With("certName", peerCert.Certificate.Name()) } } @@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { +func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) - if l.Level >= logrus.TraceLevel { - l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName", + "interfaceName", i.Name, + "allow", allow, + ) } if !allow { @@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { } 
if !addr.IsValid() { - if l.Level >= logrus.DebugLevel { - l.WithField("localAddr", rawAddr).Debug("addr was invalid") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("addr was invalid", "localAddr", rawAddr) } continue } @@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { isAllowed := allowList.Allow(addr) - if l.Level >= logrus.TraceLevel { - l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow", + "localAddr", addr, + "allowed", isAllowed, + ) } if !isAllowed { continue diff --git a/hostmap_test.go b/hostmap_test.go index e34a4ad0..2bd7bd43 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { func TestHostMap_reload(t *testing.T) { l := test.NewLogger() - c := config.NewC(l) + c := config.NewC(test.NewLogger()) hm := NewHostMapFromConfig(l, c) diff --git a/hostmap_tester.go b/hostmap_tester.go index fe40c533..a6ac6d44 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -1,5 +1,4 @@ //go:build e2e_testing -// +build e2e_testing package nebula diff --git a/inside.go b/inside.go index 0d53f952..68cb38ec 100644 --- a/inside.go +++ b/inside.go @@ -1,9 +1,10 @@ package nebula import ( + "context" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" @@ -14,8 +15,11 @@ import ( func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: 
%s", err) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while validating outbound packet", + "packet", packet, + "error", err, + ) } return } @@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) if err != nil { - f.l.WithError(err).Error("Failed to forward to tun") + f.l.Error("Failed to forward to tun", "error", err) } } // Otherwise, drop. On linux, we should never see these packets - Linux @@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if hostinfo == nil { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", fwPacket.RemoteAddr). - WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", + "vpnAddr", fwPacket.RemoteAddr, + "fwPacket", fwPacket, + ) } return } @@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } else { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l). - WithField("fwPacket", fwPacket). - WithField("reason", dropReason). 
- Debugln("dropping outbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping outbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } } } @@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { _, err := f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } } @@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * } if len(out) > iputil.MaxRejectPacketSize { - if f.l.GetLevel() >= logrus.InfoLevel { - f.l. - WithField("packet", packet). - WithField("outPacket", out). - Info("rejectOutside: packet too big, not sending") + if f.l.Enabled(context.Background(), slog.LevelInfo) { + f.l.Info("rejectOutside: packet too big, not sending", + "packet", packet, + "outPacket", out, + ) } return } @@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac // This would also need to interact with unsafe_route updates through reloading the config or // use of the use_system_route_table option - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("destination", destinationAddr). - WithField("originalGateway", gatewayAddr). 
- Debugln("Calculated gateway for ECMP not available, attempting other gateways") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways", + "destination", destinationAddr, + "originalGateway", gatewayAddr, + ) } for i := range gateways { @@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { - f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) + f.l.Warn("error while parsing outgoing packet for firewall check", "error", err) return } // check if packet is in outbound fw rules dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("fwPacket", fp). - WithField("reason", dropReason). - Debugln("dropping cached packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping cached packet", + "fwPacket", fp, + "reason", dropReason, + ) } return } @@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message }) if hostInfo == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", vpnAddr). - Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes", + "vpnAddr", vpnAddr, + ) } return } @@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo, if noiseutil.EncryptLockNeeded { via.ConnectionState.writeLock.Unlock() } - via.logger(f.l). - WithField("outCap", cap(out)). - WithField("payloadLen", len(ad)). - WithField("headerLen", len(out)). - WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()). 
- Error("SendVia out buffer not large enough for relay") + via.logger(f.l).Error("SendVia out buffer not large enough for relay", + "outCap", cap(out), + "payloadLen", len(ad), + "headerLen", len(out), + "cipherOverhead", via.ConnectionState.eKey.Overhead(), + ) return } @@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo, via.ConnectionState.writeLock.Unlock() } if err != nil { - via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") + via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) return } err = f.writers[0].WriteTo(out, via.remote) if err != nil { - via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") + via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) } f.connectionManager.RelayUsed(relay.LocalIndex) } @@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Lighthouse update triggered for punch due to rebind counter", + "vpnAddrs", hostinfo.vpnAddrs, + ) } } @@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType ci.writeLock.Unlock() } if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).WithField("counter", c). - WithField("attemptedCounter", c). 
- Error("Failed to encrypt outgoing packet") + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", err, + "udpAddr", remote, + "counter", c, + "attemptedCounter", c, + ) return } if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else { // Try to send via a relay @@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) - hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", + "relay", relayIP, + "error", err, + ) continue } f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) diff --git a/inside_bsd.go b/inside_bsd.go index c9c7730d..dc847878 100644 --- a/inside_bsd.go +++ b/inside_bsd.go @@ -1,5 +1,4 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package nebula diff --git a/inside_generic.go b/inside_generic.go index 0bb2345a..bdcc1a6a 100644 --- a/inside_generic.go +++ b/inside_generic.go @@ -1,5 +1,4 @@ //go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd -// +build !darwin,!dragonfly,!freebsd,!netbsd,!openbsd package nebula diff --git a/interface.go b/interface.go 
index 0027c18a..f4dbfc19 100644 --- a/interface.go +++ b/interface.go @@ -6,15 +6,15 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" - "os" - "runtime" + "sync" "sync/atomic" "time" "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -31,7 +31,7 @@ type InterfaceConfig struct { pki *PKI Cipher string Firewall *Firewall - ServeDns bool + DnsServer *dnsServer HandshakeManager *HandshakeManager lightHouse *LightHouse connectionManager *connectionManager @@ -48,7 +48,7 @@ type InterfaceConfig struct { reQueryWait time.Duration ConntrackCacheTimeout time.Duration - l *logrus.Logger + l *slog.Logger } type Interface struct { @@ -59,7 +59,7 @@ type Interface struct { firewall *Firewall connectionManager *connectionManager handshakeManager *HandshakeManager - serveDns bool + dnsServer *dnsServer createTime time.Time lightHouse *LightHouse myBroadcastAddrsTable *bart.Lite @@ -87,14 +87,22 @@ type Interface struct { conntrackCacheTimeout time.Duration + ctx context.Context writers []udp.Conn readers []io.ReadWriteCloser + wg sync.WaitGroup + + // fatalErr holds the first unexpected reader error that caused shutdown. 
+ // nil means "no fatal error" (yet) + fatalErr atomic.Pointer[error] + // triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr + triggerShutdown func() metricHandshakes metrics.Histogram messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics - l *logrus.Logger + l *slog.Logger } type EncWriter interface { @@ -165,12 +173,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { cs := c.pki.getCertState() ifce := &Interface{ + ctx: ctx, pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, firewall: c.Firewall, - serveDns: c.ServeDns, + dnsServer: c.DnsServer, handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, @@ -211,19 +220,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { // activate creates the interface on the host. After the interface is created, any // other services that want to bind listeners to its IP may do so successfully. However, // the interface isn't going to process anything until run() is called. -func (f *Interface) activate() { +func (f *Interface) activate() error { // actually turn on tun dev addr, err := f.outside.LocalAddr() if err != nil { - f.l.WithError(err).Error("Failed to get udp listen address") + f.l.Error("Failed to get udp listen address", "error", err) } - f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). - WithField("build", f.version).WithField("udpAddr", addr). - WithField("boringcrypto", boringEnabled()). - WithField("fips140", fips140.Enabled()). 
- Info("Nebula interface is active") + f.l.Info("Nebula interface is active", + "interface", f.inside.Name(), + "networks", f.myVpnNetworks, + "build", f.version, + "udpAddr", addr, + "boringcrypto", boringEnabled(), + "fips140", fips140.Enabled(), + ) if f.routines > 1 { if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() { @@ -240,33 +252,58 @@ func (f *Interface) activate() { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { - f.l.Fatal(err) + return err } } f.readers[i] = reader } - if err := f.inside.Activate(); err != nil { + f.wg.Add(1) // for us to wait on Close() to return + if err = f.inside.Activate(); err != nil { + f.wg.Done() f.inside.Close() - f.l.Fatal(err) + return err } + + return nil } -func (f *Interface) run() { +func (f *Interface) run() (func() error, error) { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { - go f.listenOut(i) + f.wg.Go(func() { + f.listenOut(i) + }) } // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { - go f.listenIn(f.readers[i], i) + f.wg.Go(func() { + f.listenIn(f.readers[i], i) + }) + } + + return func() error { + f.wg.Wait() + if e := f.fatalErr.Load(); e != nil { + return *e + } + return nil + }, nil +} + +// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one +func (f *Interface) onFatal(err error) { + swapped := f.fatalErr.CompareAndSwap(nil, &err) + if !swapped { + return + } + if f.triggerShutdown != nil { + f.triggerShutdown() } } func (f *Interface) listenOut(i int) { - runtime.LockOSThread() - var li udp.Conn if i > 0 { li = f.writers[i] @@ -274,42 +311,47 @@ func (f *Interface) listenOut(i int) { li = f.outside } - ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() plaintext := make([]byte, udp.MTU) h := &header.H{} fwPacket := 
&firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) }) + + if err != nil && !f.closed.Load() { + f.l.Error("Error while reading inbound packet, closing", "error", err) + f.onFatal(err) + } + + f.l.Debug("underlay reader is done", "reader", i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - runtime.LockOSThread() - packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { - return + if !f.closed.Load() { + f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) + f.onFatal(err) } - - f.l.WithError(err).Error("Error while reading outbound packet") - // This only seems to happen when something fatal happens to the fd, so exit. 
- os.Exit(2) + break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) } + + f.l.Debug("overlay reader is done", "reader", i) } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -329,7 +371,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { if initial || c.HasChanged("pki.disconnect_invalid") { f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) if !initial { - f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) + f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load()) } } } @@ -343,7 +385,7 @@ func (f *Interface) reloadFirewall(c *config.C) { fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { - f.l.WithError(err).Error("Error while creating firewall during reload") + f.l.Error("Error while creating firewall during reload", "error", err) return } @@ -356,10 +398,11 @@ func (f *Interface) reloadFirewall(c *config.C) { // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { - f.l.WithField("firewallHashes", fw.GetRuleHashes()). - WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). - Warn("firewall rulesVersion has overflowed, resetting conntrack") + f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } else { fw.Conntrack = conntrack } @@ -367,10 +410,11 @@ func (f *Interface) reloadFirewall(c *config.C) { f.firewall = fw oldFw.Destroy() - f.l.WithField("firewallHashes", fw.GetRuleHashes()). - WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). 
- Info("New firewall has been installed") + f.l.Info("New firewall has been installed", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } func (f *Interface) reloadSendRecvError(c *config.C) { @@ -392,8 +436,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) { } } - f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()). - Info("Loaded send_recv_error config") + f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String()) } } @@ -416,8 +459,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) { } } - f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()). - Info("Loaded accept_recv_error config") + f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String()) } } @@ -484,23 +526,23 @@ func (f *Interface) GetCertState() *CertState { } func (f *Interface) Close() error { + var errs []error f.closed.Store(true) - for _, u := range f.writers { + // Release the udp readers + for i, u := range f.writers { err := u.Close() if err != nil { - f.l.WithError(err).Error("Error while closing udp socket") - } - } - for i, r := range f.readers { - if i == 0 { - continue // f.readers[0] is f.inside, which we want to save for last - } - if err := r.Close(); err != nil { - f.l.WithError(err).Error("Error while closing tun reader") + f.l.Error("Error while closing udp socket", "error", err, "writer", i) + errs = append(errs, err) } } - // Release the tun device - return f.inside.Close() + // Release the tun device (closing the tun also closes all readers) + closeErr := f.inside.Close() + if closeErr != nil { + errs = append(errs, closeErr) + } + f.wg.Done() + return errors.Join(errs...) 
} diff --git a/lighthouse.go b/lighthouse.go index 1510b942..6034e68c 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -15,10 +16,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -69,18 +70,19 @@ type LightHouse struct { // Addr's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan netip.Addr + updateTrigger chan struct{} + queryChan chan netip.Addr calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter - l *logrus.Logger + l *slog.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -105,6 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, nebulaPort: nebulaPort, punchConn: pc, punchy: p, + updateTrigger: make(chan struct{}, 1), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } @@ -131,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, case *util.ContextualError: v.Log(l) case 
error: - l.WithError(err).Error("failed to reload lighthouse") + l.Error("failed to reload lighthouse", "error", err) } }) @@ -203,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used addr := addrs[0].Unmap() if lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithField("addr", rawAddr).WithField("entry", i+1). - Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") + lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range", + "addr", rawAddr, + "entry", i+1, + ) continue } @@ -222,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10))) if !initial { - lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load()) + lh.l.Info("lighthouse.interval changed", + "interval", lh.interval.Load(), + ) if lh.updateCancel != nil { // May not always have a running routine @@ -316,6 +323,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if !initial { //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic lh.l.Info("lighthouse.hosts has changed") + lh.TriggerUpdate() } } @@ -333,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { for _, v := range c.GetStringSlice("relay.relays", nil) { configRIP, err := netip.ParseAddr(v) if err != nil { - lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + lh.l.Warn("Parse relay from config failed", + "relay", v, + "error", err, + ) } else { - lh.l.WithField("relay", v).Info("Read relay from config") + lh.l.Info("Read relay from config", "relay", v) relaysForMe = append(relaysForMe, configRIP) } } @@ -360,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) { } if 
!lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). - Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not") + lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not", + "vpnAddr", addr, + "networks", lh.myVpnNetworks, + ) } out[i] = addr } @@ -432,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc } if !lh.myVpnNetworksTable.Contains(vpnAddr) { - lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). - Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work") + lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work", + "vpnAddr", vpnAddr, + "networks", lh.myVpnNetworks, + "entry", i+1, + ) } vals, ok := v.([]any) @@ -534,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { lh.Lock() rm, ok := lh.addrMap[allVpnAddrs[0]] if ok { + debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug) for _, addr := range allVpnAddrs { srm := lh.addrMap[addr] if srm == rm { delete(lh.addrMap, addr) - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", addr) + if debugEnabled { + lh.l.Debug("deleting from lighthouse", "vpnAddr", addr) } } } @@ -656,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow). 
- Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddrs", vpnAddrs, + "udpAddr", to, + "allow", allow, + ) } if !allow { return false @@ -675,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { udpAddr := protoV4AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). - Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -695,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { udpAddr := protoV6AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). 
- Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -713,21 +739,14 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { l := lh.GetLighthouses() - for i := range l { - if l[i] == vpnAddr { - return true - } - } - return false + return slices.Contains(l, vpnAddr) } func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool { l := lh.GetLighthouses() for i := range vpnAddrs { - for j := range l { - if l[j] == vpnAddrs[i] { - return true - } + if slices.Contains(l, vpnAddrs[i]) { + return true } } return false @@ -779,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { if v == cert.Version1 { if !addr.Is4() { - lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). - Error("Can't query lighthouse for v6 address using a v1 protocol") + lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol", + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } @@ -791,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v1Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). - Error("Failed to marshal lighthouse v1 query payload") + lh.l.Error("Failed to marshal lighthouse v1 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -808,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v2Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). 
- Error("Failed to marshal lighthouse v2 query payload") + lh.l.Error("Failed to marshal lighthouse v2 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -819,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { queried++ } else { - lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + lh.l.Debug("unsupported protocol version", + "op", "query", + "queryVpnAddr", addr, + "version", v, + ) continue } } @@ -848,11 +877,24 @@ func (lh *LightHouse) StartUpdateWorker() { return case <-clockSource.C: continue + case <-lh.updateTrigger: + continue } } }() } +// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking +// operation intended to be called after a handshake completes with a lighthouse, +// so the lighthouse has our current addresses without waiting for the next +// periodic update. +func (lh *LightHouse) TriggerUpdate() { + select { + case lh.updateTrigger <- struct{}{}: + default: + } +} + func (lh *LightHouse) SendUpdate() { var v4 []*V4AddrPort var v6 []*V6AddrPort @@ -898,8 +940,9 @@ func (lh *LightHouse) SendUpdate() { if v == cert.Version1 { if v1Update == nil { if !lh.myVpnNetworks[0].Addr().Is4() { - lh.l.WithField("lighthouseAddr", lhVpnAddr). - Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address", + "lighthouseAddr", lhVpnAddr, + ) continue } var relays []uint32 @@ -923,8 +966,10 @@ func (lh *LightHouse) SendUpdate() { v1Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). 
- Error("Error while marshaling for lighthouse v1 update") + lh.l.Error("Error while marshaling for lighthouse v1 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -950,8 +995,10 @@ func (lh *LightHouse) SendUpdate() { v2Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). - Error("Error while marshaling for lighthouse v2 update") + lh.l.Error("Error while marshaling for lighthouse v2 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -960,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() { updated++ } else { - lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + lh.l.Debug("unsupported protocol version", + "op", "update", + "version", v, + ) continue } } @@ -974,7 +1024,7 @@ type LightHouseHandler struct { out []byte pb []byte meta *NebulaMeta - l *logrus.Logger + l *slog.Logger } func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { @@ -1023,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). - Error("Failed to unmarshal lighthouse packet") + lhh.l.Error("Failed to unmarshal lighthouse packet", + "error", err, + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). 
- Error("Invalid lighthouse update") + lhh.l.Error("Invalid lighthouse update", + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } @@ -1058,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I don't answer queries, but received from: ", addr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I don't answer queries, but received one", "from", addr) } return } queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). - Debugln("Dropping malformed HostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Dropping malformed HostQuery", + "from", fromVpnAddrs, + "details", n.Details, + ) } return } if useVersion == cert.Version1 && queryVpnAddr.Is6() { // this case really shouldn't be possible to represent, but reject it anyway. - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). 
- Debugln("invalid vpn addr for v1 handleHostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("invalid vpn addr for v1 handleHostQuery", + "vpnAddrs", fromVpnAddrs, + "queryVpnAddr", queryVpnAddr, + ) } return } @@ -1101,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.Error("Failed to marshal lighthouse host query reply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1129,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd if ok { whereToPunch = newDest } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unable to punch to host, no addresses in common", + "to", crt.Networks(), + ) } } } @@ -1156,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") + lhh.l.Error("Failed to marshal lighthouse host was queried for", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1198,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) } } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("version", v).Debug("unsupported protocol version") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unsupported protocol version", + "op", "coalesceAnswers", + "version", v, + ) } } } @@ -1212,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ certVpnAddr, _, err := 
n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Error("dropping malformed HostQueryReply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) } return } @@ -1238,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs) } return } @@ -1262,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp //Simple check that the host sent this not someone else, if detailsVpnAddr is filled if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Host sent invalid update", + "vpnAddrs", fromVpnAddrs, + "answer", detailsVpnAddr, + ) } return } @@ -1285,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp switch useVersion { case cert.Version1: if !fromVpnAddrs[0].Is4() { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message", + "vpnAddrs", fromVpnAddrs, + ) return } vpnAddrB := fromVpnAddrs[0].As4() @@ -1293,13 +1371,16 @@ func (lhh 
*LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp case cert.Version2: // do nothing, we want to send a blank message default: - lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + lhh.l.Error("invalid protocol version", "useVersion", useVersion) return } ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") + lhh.l.Error("Failed to marshal lighthouse host update ack", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1316,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("dropping invalid HostPunchNotification", + "details", n.Details, + "error", err, + ) } return } @@ -1334,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn lhh.lh.punchConn.WriteTo(empty, vpnPeer) }() - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Punching", + "vpnPeer", vpnPeer, + "logVpnAddr", logVpnAddr, + ) } } @@ -1360,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn if lhh.lh.punchy.GetRespond() { go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Sending a nebula test packet", + "vpnAddr", detailsVpnAddr, + ) } //NOTE: we have to allocate a new output buffer here since we are spawning a 
new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine diff --git a/logger.go b/logger.go deleted file mode 100644 index aaf6f29c..00000000 --- a/logger.go +++ /dev/null @@ -1,45 +0,0 @@ -package nebula - -import ( - "fmt" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" -) - -func configLogger(l *logrus.Logger, c *config.C) error { - // set up our logging level - logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) - if err != nil { - return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) - } - l.SetLevel(logLevel) - - disableTimestamp := c.GetBool("logging.disable_timestamp", false) - timestampFormat := c.GetString("logging.timestamp_format", "") - fullTimestamp := (timestampFormat != "") - if timestampFormat == "" { - timestampFormat = time.RFC3339 - } - - logFormat := strings.ToLower(c.GetString("logging.format", "text")) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{ - TimestampFormat: timestampFormat, - FullTimestamp: fullTimestamp, - DisableTimestamp: disableTimestamp, - } - case "json": - l.Formatter = &logrus.JSONFormatter{ - TimestampFormat: timestampFormat, - DisableTimestamp: disableTimestamp, - } - default: - return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) - } - - return nil -} diff --git a/logging/logger.go b/logging/logger.go new file mode 100644 index 00000000..bbc10bb3 --- /dev/null +++ b/logging/logger.go @@ -0,0 +1,233 @@ +// Package logging wires the nebula runtime-reconfigurable slog handler used +// by nebula.Main and the nebula CLI binaries. Callers build a logger with +// NewLogger, then call ApplyConfig at startup and from a config reload +// callback to push logging.level, logging.format, and +// logging.disable_timestamp changes onto the logger without rebuilding it. 
+package logging + +import ( + "context" + "fmt" + "io" + "log/slog" + "strings" + "sync/atomic" + "time" +) + +// Config is the subset of *config.C that ApplyConfig reads. Declaring it +// here keeps the logging package from depending on config directly, which +// would cycle through the shared test helpers (test.NewLogger imports +// logging, and config's tests import test). *config.C satisfies this +// interface structurally with no adapter. +type Config interface { + GetString(key, def string) string + GetBool(key string, def bool) bool +} + +// LevelTrace is a custom slog level below Debug, used when logging.level is +// "trace". slog has no builtin trace level; the value is one step below +// slog.LevelDebug in slog's 4-point spacing. +const LevelTrace = slog.Level(-8) + +// NewLogger returns a *slog.Logger whose level, format, and timestamp +// emission can be reconfigured at runtime via ApplyConfig and the SSH debug +// commands. The default configuration is info-level text output so log +// calls made before ApplyConfig runs still produce output. Timestamps +// follow slog's default RFC3339Nano format; set logging.disable_timestamp +// in config to suppress them. +// +// ApplyConfig and the SSH commands discover the reconfig surface via +// structural type-assertion on l.Handler(), so replacement implementations +// (tests, platform-specific sinks) need only implement the subset of +// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)} +// they care about. Callers that pass a plain *slog.Logger without these +// methods get a silent no-op; reconfiguration is always opt-in. +func NewLogger(w io.Writer) *slog.Logger { + return slog.New(NewHandler(w)) +} + +// NewHandler builds the *Handler that NewLogger wraps. 
Exported for +// platform-specific sinks (notably cmd/nebula-service/logs_windows.go) +// that want to wrap the handler with extra behavior, such as tagging each +// record with its Event Log severity, while still benefiting from all the +// level / format / timestamp / WithAttrs machinery implemented here. +func NewHandler(w io.Writer) *Handler { + root := &handlerRoot{} + root.level.Set(slog.LevelInfo) + opts := &slog.HandlerOptions{Level: &root.level} + return &Handler{ + root: root, + text: slog.NewTextHandler(w, opts), + json: slog.NewJSONHandler(w, opts), + } +} + +// handlerRoot carries the reconfiguration state shared by every logger +// derived from a NewHandler call. All fields are consulted on the log +// path and updated lock-free. +type handlerRoot struct { + level slog.LevelVar + disableTimestamp atomic.Bool + // jsonMode picks which of the pre-derived inner handlers Handler.Handle + // dispatches to. Flipping it propagates instantly to every derived logger + // without rebuilding or chain-replaying anything. + jsonMode atomic.Bool +} + +// Handler is the slog.Handler returned by NewHandler. It holds two +// pre-derived slog handlers -- one text, one json -- both built from the +// same accumulated WithAttrs/WithGroup state. Handle picks which one to +// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes +// effect immediately across the whole process without having to rebuild +// any derived loggers. 
+type Handler struct { + root *handlerRoot + text slog.Handler + json slog.Handler +} + +func (h *Handler) Enabled(_ context.Context, l slog.Level) bool { + return h.root.level.Level() <= l +} + +func (h *Handler) Handle(ctx context.Context, r slog.Record) error { + if h.root.disableTimestamp.Load() { + r.Time = time.Time{} + } + if h.root.jsonMode.Load() { + return h.json.Handle(ctx, r) + } + return h.text.Handle(ctx, r) +} + +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithAttrs(attrs), + json: h.json.WithAttrs(attrs), + } +} + +func (h *Handler) WithGroup(name string) slog.Handler { + if name == "" { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithGroup(name), + json: h.json.WithGroup(name), + } +} + +// SetLevel updates the effective log level. Propagates to every derived +// logger via the shared LevelVar. +func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) } + +// GetLevel reports the current log level. +func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() } + +// SetFormat flips the output format atomically. Valid formats are "text" +// and "json". Every derived logger sees the new format on its next Handle +// call; no rebuild or registration is required. +func (h *Handler) SetFormat(format string) error { + switch format { + case "text": + h.root.jsonMode.Store(false) + case "json": + h.root.jsonMode.Store(true) + default: + return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"}) + } + return nil +} + +// GetFormat reports the currently selected format name. +func (h *Handler) GetFormat() string { + if h.root.jsonMode.Load() { + return "json" + } + return "text" +} + +// SetDisableTimestamp toggles whether Handle zeroes r.Time before +// dispatching (slog's builtin text/json handlers skip emitting the time +// attribute on a zero time). 
+func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) } + +// ApplyConfig reads logging.level, logging.format, and (optionally) +// logging.disable_timestamp from c and applies them to l. The reconfig +// surface is discovered via structural type-assertion on l.Handler(), so +// foreign handlers silently opt out of whichever capabilities they do not +// implement. +// +// nebula.Main does NOT call this function on your behalf; callers that want +// config-driven log level / format / timestamp updates invoke it at +// startup and register it as a reload callback themselves. This keeps the +// library from mutating an embedder's logger without their say-so. +func ApplyConfig(l *slog.Logger, c Config) error { + h := l.Handler() + + lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) + if err != nil { + return err + } + if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok { + ls.SetLevel(lvl) + } + + format := strings.ToLower(c.GetString("logging.format", "text")) + if fs, ok := h.(interface{ SetFormat(string) error }); ok { + if err := fs.SetFormat(format); err != nil { + return err + } + } + + if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok { + ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false)) + } + return nil +} + +// ParseLevel converts a config-string level name ("trace", "debug", "info", +// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and +// "panic" are accepted for backwards compatibility with pre-slog configs +// and both map to slog.LevelError. 
+func ParseLevel(s string) (slog.Level, error) { + switch s { + case "trace": + return LevelTrace, nil + case "debug": + return slog.LevelDebug, nil + case "info": + return slog.LevelInfo, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + case "fatal", "panic": + return slog.LevelError, nil + default: + return 0, fmt.Errorf("not a valid logging level: %q", s) + } +} + +// LevelName returns a human-readable name for a slog.Level matching the +// strings accepted by ParseLevel. +func LevelName(l slog.Level) string { + switch { + case l <= LevelTrace: + return "trace" + case l <= slog.LevelDebug: + return "debug" + case l <= slog.LevelInfo: + return "info" + case l <= slog.LevelWarn: + return "warn" + default: + return "error" + } +} diff --git a/logging/logger_bench_test.go b/logging/logger_bench_test.go new file mode 100644 index 00000000..eb29c1c3 --- /dev/null +++ b/logging/logger_bench_test.go @@ -0,0 +1,90 @@ +package logging + +import ( + "context" + "io" + "log/slog" + "testing" +) + +// BenchmarkLogger_* compare the handler returned by NewLogger against a +// stock slog text handler. The key thing we care about is the per-log +// cost on a logger that has been derived via .With(), because that is the +// shape subsystems store on their structs (HostInfo.logger(), +// lh.l.With("subsystem", ...), etc.) and call from hot paths. 
+ +func BenchmarkLogger_Stock_RootInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_RootInfo(b *testing.B) { + l := NewLogger(io.Discard) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +// Gated-off-path benchmarks: mimic the typical hot-path shape +// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below +// the active level. This is the dominant pattern in inside.go/outside.go and +// what we pay on every packet. 
+func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} + +func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} diff --git a/main.go b/main.go index 17aaa548..eef13c97 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,13 @@ package nebula import ( "context" "fmt" + "log/slog" "net" "net/netip" "runtime/debug" "strings" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/sshd" @@ -20,7 +20,7 @@ import ( type m = map[string]any -func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { +func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. 
defer func() { @@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg buildVersion = moduleVersion() } - l := logger - l.Formatter = &logrus.TextFormatter{ - FullTimestamp: true, - } - // Print the config if in test, the exit comes later if configTest { b, err := yaml.Marshal(c.Settings) @@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } // Print the final config - l.Println(string(b)) + l.Info(string(b)) } - err := configLogger(l, c) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) - } - - c.RegisterReloadCallback(func(c *config.C) { - err := configLogger(l, c) - if err != nil { - l.WithError(err).Error("Failed to configure the logger") - } - }) - pki, err := NewPKIFromConfig(l, c) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) @@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } - l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") + l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes()) - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) + ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) } @@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") + l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err) sshStart = nil } } @@ -99,19 +82,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg routines = 1 } 
if routines > 1 { - l.WithField("routines", routines).Info("Using multiple routines") + l.Info("Using multiple routines", "routines", routines) } } else { // deprecated and undocumented tunQueues := c.GetInt("tun.routines", 1) udpQueues := c.GetInt("listen.routines", 1) - if tunQueues > udpQueues { - routines = tunQueues - } else { - routines = udpQueues - } + routines = max(tunQueues, udpQueues) if routines != 1 { - l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead") + l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines) } } @@ -124,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg conntrackCacheTimeout = 1 * time.Second } if conntrackCacheTimeout > 0 { - l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") + l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout) } var tun overlay.Device @@ -170,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } for i := 0; i < routines; i++ { - l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port))) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) @@ -219,13 +198,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - serveDns := false - if c.GetBool("lighthouse.serve_dns", false) { - if c.GetBool("lighthouse.am_lighthouse", false) { - serveDns = true - } else { - l.Warn("DNS server refusing to run because this host is not a lighthouse.") 
- } + ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) + if err != nil { + l.Warn("Failed to start DNS responder", "error", err) } ifConfig := &InterfaceConfig{ @@ -234,7 +209,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Outside: udpConns[0], pki: pki, Firewall: fw, - ServeDns: serveDns, + DnsServer: ds, HandshakeManager: handshakeManager, connectionManager: connManager, lightHouse: lightHouse, @@ -271,7 +246,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg go handshakeManager.Run(ctx) } - statsStart, err := startStats(l, c, buildVersion, configTest) + stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } @@ -284,23 +259,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg attachCommands(l, c, ssh, ifce) - // Start DNS server last to allow using the nebula IP as lighthouse.dns.host - var dnsStart func() - if lightHouse.amLighthouse && serveDns { - l.Debugln("Starting dns server") - dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) - } - return &Control{ - ifce, - l, - ctx, - cancel, - sshStart, - statsStart, - dnsStart, - lightHouse.StartUpdateWorker, - connManager.Start, + state: StateReady, + f: ifce, + l: l, + ctx: ctx, + cancel: cancel, + sshStart: sshStart, + statsStart: stats.Start, + dnsStart: ds.Start, + lighthouseStart: lightHouse.StartUpdateWorker, + connectionManagerStart: connManager.Start, }, nil } diff --git a/noise.go b/noise.go index 392f0b6b..eae268ed 100644 --- a/noise.go +++ b/noise.go @@ -15,14 +15,12 @@ type endianness interface { var noiseEndianness endianness = binary.BigEndian type NebulaCipherState struct { - c noise.Cipher - //k [32]byte - //n uint64 + c cipher.AEAD } func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { - return &NebulaCipherState{c: s.Cipher()} - + x := 
s.Cipher() + return &NebulaCipherState{c: x.(cipher.AEAD)} } type cipherAEADDanger interface { @@ -40,25 +38,20 @@ type cipherAEADDanger interface { // be re-used by callers to minimize garbage collection. func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { if s != nil { - switch ce := s.c.(type) { - case cipherAEADDanger: - return ce.EncryptDanger(out, ad, plaintext, n, nb) - default: - // TODO: Is this okay now that we have made messageCounter atomic? - // Alternative may be to split the counter space into ranges - //if n <= s.n { - // return nil, errors.New("CRITICAL: a duplicate counter value was used") - //} - //s.n = n - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad) - //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) - return out, nil - } + // TODO: Is this okay now that we have made messageCounter atomic? 
+ // Alternative may be to split the counter space into ranges + //if n <= s.n { + // return nil, errors.New("CRITICAL: a duplicate counter value was used") + //} + //s.n = n + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + noiseEndianness.PutUint64(nb[4:], n) + out = s.c.Seal(out, nb, plaintext, ad) + //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) + return out, nil } else { return nil, errors.New("no cipher state available to encrypt") } @@ -66,17 +59,12 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { if s != nil { - switch ce := s.c.(type) { - case cipherAEADDanger: - return ce.DecryptDanger(out, ad, ciphertext, n, nb) - default: - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad) - } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + noiseEndianness.PutUint64(nb[4:], n) + return s.c.Open(out, nb, ciphertext, ad) } else { return []byte{}, nil } @@ -84,7 +72,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, func (s *NebulaCipherState) Overhead() int { if s != nil { - return s.c.(cipher.AEAD).Overhead() + return s.c.Overhead() } return 0 } diff --git a/notboring.go b/notboring.go index c86b0bc3..f138a0a6 100644 --- a/notboring.go +++ b/notboring.go @@ -1,5 +1,4 @@ //go:build !boringcrypto -// +build !boringcrypto package nebula diff --git a/outside.go b/outside.go index 172c3e83..1e00a0a9 100644 --- a/outside.go +++ b/outside.go @@ -1,15 +1,16 @@ package nebula import ( + "context" "encoding/binary" "errors" + "log/slog" "net/netip" "time" "github.com/google/gopacket/layers" "golang.org/x/net/ipv6" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "golang.org/x/net/ipv4" @@ -24,7 +25,11 @@ 
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) + f.l.Info("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) } return } @@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Refusing to process double encrypted packet", "from", via) } return } @@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. 
- hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } @@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) return } @@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) return } } @@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). 
- Error("Failed to decrypt lighthouse packet") + hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt test packet") + hostinfo.logger(f.l).Error("Failed to decrypt test packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -190,9 +210,17 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if !f.handleEncrypted(ci, via, h) { return } + _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + if err != nil { + hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", + "error", err, + "from", via, + "packet", packet, + ) + return + } - hostinfo.logger(f.l).WithField("from", via). - Info("Close tunnel received, tearing down.") + hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) f.closeTunnel(hostinfo) return @@ -204,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). 
- Error("Failed to decrypt Control packet") + hostinfo.logger(f.l).Error("Failed to decrypt Control packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -214,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) + } return } @@ -240,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") - return - } - - if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). - Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). 
- Info("Host roamed to new udp ip/port.") + if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote", + "suppressSeconds", RoamingSuppressSeconds, + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) + } + return + } + + hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.", + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote hostinfo.SetRemote(via.UdpAddr) @@ -327,13 +366,29 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { proto := layers.IPProtocol(data[protoAt]) switch proto { - case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: + case layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) fp.RemotePort = 0 fp.LocalPort = 0 fp.Fragment = false return nil + case layers.IPProtocolICMPv6: + if dataLen < offset+6 { + return ErrIPv6PacketTooShort + } + fp.Protocol = uint8(proto) + fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6 + icmptype := data[offset+1] + switch icmptype { + case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply: + fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier + default: + fp.RemotePort = 0 + } + fp.Fragment = false + return nil + case layers.IPProtocolTCP, layers.IPProtocolUDP: if dataLen < offset+4 { return ErrIPv6PacketTooShort @@ -423,34 +478,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Accounting for a variable header length, do we have enough data for our src/dst tuples? 
minLen := ihl - if !fp.Fragment && fp.Protocol != firewall.ProtoICMP { - minLen += minFwPacketLen + if !fp.Fragment { + if fp.Protocol == firewall.ProtoICMP { + minLen += minFwPacketLen + 2 + } else { + minLen += minFwPacketLen + } } + if len(data) < minLen { return ErrIPv4InvalidHeaderLength } - // Firewall packets are locally oriented - if incoming { + if incoming { // Firewall packets are locally oriented fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = 0 - fp.LocalPort = 0 - } else { - fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - } } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = 0 - fp.LocalPort = 0 - } else { - fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - } + } + + if fp.Fragment { + fp.RemotePort = 0 + fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = 0 //code would be uint16(data[ihl+1]) + } else if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } return nil @@ -464,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - hostinfo.logger(f.l).WithField("header", h). 
- Debugln("dropping out of window packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h) + } return nil, errors.New("out of window packet") } @@ -477,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") + hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) return false } err = newPacket(out, true, fwPacket) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). - Warnf("Error while validating inbound packet") + hostinfo.logger(f.l).Warn("Error while validating inbound packet", + "error", err, + "packet", out, + ) return false } if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - Debugln("dropping out of window packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) + } return false } @@ -499,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - WithField("reason", dropReason). 
- Debugln("dropping inbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping inbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } return false } @@ -510,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } return true } @@ -526,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) _ = f.outside.WriteTo(b, endpoint) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", index). - WithField("udpAddr", endpoint). - Debug("Recv error sent") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error sent", + "index", index, + "udpAddr", endpoint, + ) } } func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if !f.acceptRecvErrorConfig.ShouldRecvError(addr) { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). - Debug("Recv error received, ignoring") + f.l.Debug("Recv error received, ignoring", + "index", h.RemoteIndex, + "udpAddr", addr, + ) return } - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). 
- Debug("Recv error received") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error received", + "index", h.RemoteIndex, + "udpAddr", addr, + ) } hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) if hostinfo == nil { - f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") + f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex) return } if hostinfo.remote.IsValid() && hostinfo.remote != addr { - f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) + f.l.Info("Someone spoofing recv_errors?", + "addr", addr, + "hostinfoRemote", hostinfo.remote, + ) return } diff --git a/outside_test.go b/outside_test.go index 38dbef62..042ccbb3 100644 --- a/outside_test.go +++ b/outside_test.go @@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) { // next layer, missing length byte err = newPacket(buffer.Bytes()[:49], true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + err = nil // A good ICMP packet ip = layers.IPv6{ @@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) { DstIP: net.IPv6linklocalallnodes, } - icmp := layers.ICMPv6{} - - buffer.Clear() - err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp) - if err != nil { - panic(err) + icmp := layers.ICMPv6{ + TypeCode: layers.ICMPv6TypeEchoRequest, + Checksum: 0x1234, } - err = newPacket(buffer.Bytes(), true, p) - require.NoError(t, err) + buffer.Clear() + require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp)) + require.Error(t, newPacket(buffer.Bytes(), true, p)) + + buffer.Clear() + echo := layers.ICMPv6Echo{ + Identifier: 0xabcd, + SeqNumber: 1234, + } + require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo)) + require.NoError(t, newPacket(buffer.Bytes(), true, p)) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) - 
assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0xabcd), p.RemotePort) assert.Equal(t, uint16(0), p.LocalPort) assert.False(t, p.Fragment) @@ -574,7 +581,7 @@ func BenchmarkParseV6(b *testing.B) { } evilBytes := buffer.Bytes() - for i := 0; i < 200; i++ { + for range 200 { evilBytes = append(evilBytes, hopHeader...) } evilBytes = append(evilBytes, lastHopHeader...) diff --git a/test/tun.go b/overlay/overlaytest/noop.go similarity index 68% rename from test/tun.go rename to overlay/overlaytest/noop.go index fb32782f..956da7dd 100644 --- a/test/tun.go +++ b/overlay/overlaytest/noop.go @@ -1,4 +1,6 @@ -package test +// Package overlaytest provides fakes of overlay.Device for tests that do +// not want to touch a real tun device or route table. +package overlaytest import ( "errors" @@ -8,6 +10,9 @@ import ( "github.com/slackhq/nebula/routing" ) +// NoopTun is an overlay.Device that silently discards every read and write. +// Useful in tests that need to construct a nebula Interface but do not +// exercise the datapath. 
type NoopTun struct{} func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { diff --git a/overlay/route.go b/overlay/route.go index 61989581..c6403f91 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -2,6 +2,7 @@ package overlay import ( "fmt" + "log/slog" "math" "net" "net/netip" @@ -9,7 +10,6 @@ import ( "strconv" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) @@ -48,11 +48,14 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { +func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { - l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) + l.Warn("route MTU is not supported on this platform", + "goos", runtime.GOOS, + "route", r, + ) } gateways := r.Via diff --git a/overlay/route_test.go b/overlay/route_test.go index 9a959a55..f9d9dcd9 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 2) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") @@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 3) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("192.168.86.1") diff --git a/overlay/tun.go b/overlay/tun.go index e0bf69f6..3af1e189 100644 --- a/overlay/tun.go 
+++ b/overlay/tun.go @@ -2,10 +2,10 @@ package overlay import ( "fmt" + "log/slog" "net" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) @@ -22,9 +22,9 @@ func (e *NameError) Error() string { } // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, vpnNetworks) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..9cbb64be 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,12 +6,12 @@ package overlay import ( "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -23,10 +23,10 @@ type tun struct { vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, 
vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..524ef0cd 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" @@ -14,7 +15,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -30,7 +30,7 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger + l *slog.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -79,7 +79,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) 
(*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error { err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { - t.l.WithField("route", r.Cidr). - Warnf("unable to add unsafe_route, identical route already exists") + t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr) } else { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { @@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error { } } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..f47880dd 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -1,13 +1,14 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "strings" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/routing" ) @@ -19,10 +20,10 @@ type disabledTun struct { // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter rx metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), @@ 
-67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) { } t.tx.Inc(1) - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Write payload", "raw", prettyPacket(r)) } return copy(b, r), nil @@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { select { case t.read <- out: default: - t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") + t.l.Debug("tun_disabled: dropped ICMP Echo Reply response") } return true @@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) { // Check for ICMP Echo Request before spending time doing the full parsing if t.handleICMPEchoRequest(b) { - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b)) } - } else if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") + } else if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b)) } return len(b), nil } diff --git a/overlay/tun_file_linux_test.go b/overlay/tun_file_linux_test.go new file mode 100644 index 00000000..5ab87e05 --- /dev/null +++ b/overlay/tun_file_linux_test.go @@ -0,0 +1,120 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package overlay + +import ( + "errors" + "os" + "sync" + "testing" + "time" + + "golang.org/x/sys/unix" +) + +// newReadPipe returns a read fd. The matching write fd is registered for cleanup. +// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). 
+func newReadPipe(t *testing.T) int { + t.Helper() + var fds [2]int + if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { + t.Fatalf("pipe2: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fds[1]) }) + return fds[0] +} + +func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + t.Cleanup(func() { _ = tf.Close() }) + + done := make(chan error, 1) + go func() { + _, err := tf.Read(make([]byte, 64)) + done <- err + }() + + // Verify Read is actually blocked in poll. + select { + case err := <-done: + t.Fatalf("Read returned before shutdown signal: %v", err) + case <-time.After(50 * time.Millisecond): + } + + if err := tf.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + select { + case err := <-done: + if !errors.Is(err, os.ErrClosed) { + t.Fatalf("expected os.ErrClosed, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Read did not wake on shutdown") + } +} + +func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) { + parent, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + friend, err := parent.newFriend(newReadPipe(t)) + if err != nil { + _ = parent.Close() + t.Fatalf("newFriend: %v", err) + } + t.Cleanup(func() { + _ = friend.Close() + _ = parent.Close() + }) + + readers := []*tunFile{parent, friend} + errs := make([]error, len(readers)) + var wg sync.WaitGroup + for i, r := range readers { + wg.Add(1) + go func(i int, r *tunFile) { + defer wg.Done() + _, errs[i] = r.Read(make([]byte, 64)) + }(i, r) + } + + time.Sleep(50 * time.Millisecond) + + if err := parent.wakeForShutdown(); err != nil { + t.Fatalf("wakeForShutdown: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("readers did not wake") + } + + for i, err := range errs { + if !errors.Is(err, 
os.ErrClosed) { + t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) + } + } +} + +func TestTunFile_Close_Idempotent(t *testing.T) { + tf, err := newTunFd(newReadPipe(t)) + if err != nil { + t.Fatalf("newTunFd: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("second Close should be a no-op, got %v", err) + } +} diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 2f65b3a4..3d995553 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,15 +9,18 @@ import ( "fmt" "io" "io/fs" + "log/slog" "net/netip" + "os" "sync/atomic" "syscall" "time" "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -92,133 +95,232 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger - devFd int + l *slog.Logger + + fd int + shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls + shutdownW int // write end of the shutdown pipe; closing this signals shutdown to any blocked reader/writer + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed atomic.Bool +} + +// blockOnRead waits until the tun fd is readable or shutdown has been signaled. +// Returns os.ErrClosed if Close was called. 
+func (t *tun) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.readPoll[0].Revents + shutdownEvents := t.readPoll[1].Revents + t.readPoll[0].Revents = 0 + t.readPoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *tun) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.writePoll[0].Revents + shutdownEvents := t.writePoll[1].Revents + t.writePoll[0].Revents = 0 + t.writePoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil } func (t *tun) Read(to []byte) (int, error) { - // use readv() to read from the tunnel device, to eliminate the need for copying the buffer - if t.devFd < 0 { - return -1, syscall.EINVAL - } - // first 4 bytes is protocol family, in network byte order - head := make([]byte, 4) - - iovecs := []syscall.Iovec{ + var head [4]byte + iovecs := [2]syscall.Iovec{ {&head[0], 4}, {&to[0], uint64(len(to))}, } - - n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - - var err error - if errno != 0 { - err = syscall.Errno(errno) - } else { - err = nil - } - // fix bytes read number to exclude header - bytesRead := int(n) - if bytesRead < 0 { - return bytesRead, err - } else if bytesRead < 4 { - return 0, err - } else { - return bytesRead - 4, err + for { + n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 
2) + if errno == 0 { + bytesRead := int(n) + if bytesRead < 4 { + return 0, nil + } + return bytesRead - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnRead(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } } } // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { - // use writev() to write to the tunnel device, to eliminate the need for copying the buffer - if t.devFd < 0 { - return -1, syscall.EINVAL - } - if len(from) <= 1 { return 0, syscall.EIO } + ipVer := from[0] >> 4 - var head []byte + var head [4]byte // first 4 bytes is protocol family, in network byte order - if ipVer == 4 { - head = []byte{0, 0, 0, syscall.AF_INET} - } else if ipVer == 6 { - head = []byte{0, 0, 0, syscall.AF_INET6} - } else { + switch ipVer { + case 4: + head[3] = syscall.AF_INET + case 6: + head[3] = syscall.AF_INET6 + default: return 0, fmt.Errorf("unable to determine IP version from packet") } - iovecs := []syscall.Iovec{ + + iovecs := [2]syscall.Iovec{ {&head[0], 4}, {&from[0], uint64(len(from))}, } - - n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - - var err error - if errno != 0 { - err = syscall.Errno(errno) - } else { - err = nil + for { + n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) + if errno == 0 { + return int(n) - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnWrite(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } } - - return int(n) - 4, err } func (t *tun) Close() error { - if t.devFd >= 0 { - err := syscall.Close(t.devFd) + if t.closed.Swap(true) { + return nil + } + + // Closing the write end of the shutdown pipe causes any blocked Poll to + // return with POLLHUP on the 
shutdown fd, so readers/writers wake up and + // exit with os.ErrClosed. + if t.shutdownW >= 0 { + _ = unix.Close(t.shutdownW) + t.shutdownW = -1 + } + + if t.fd >= 0 { + if err := unix.Close(t.fd); err != nil { + t.l.Error("Error closing device", "error", err) + } + t.fd = -1 + } + + if t.shutdownR >= 0 { + _ = unix.Close(t.shutdownR) + t.shutdownR = -1 + } + + c := make(chan struct{}) + go func() { + // destroying the interface can block if a read() is still pending. Do this asynchronously. + defer close(c) + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err == nil { + defer syscall.Close(s) + ifreq := ifreqDestroy{Name: t.deviceBytes()} + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) + } if err != nil { - t.l.WithError(err).Error("Error closing device") + t.l.Error("Error destroying tunnel", "error", err) } - t.devFd = -1 + }() - c := make(chan struct{}) - go func() { - // destroying the interface can block if a read() is still pending. Do this asynchronously. 
- defer close(c) - s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) - if err == nil { - defer syscall.Close(s) - ifreq := ifreqDestroy{Name: t.deviceBytes()} - err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) - } - if err != nil { - t.l.WithError(err).Error("Error destroying tunnel") - } - }() - - // wait up to 1 second so we start blocking at the ioctl - select { - case <-c: - case <-time.After(1 * time.Second): - } + // wait up to 1 second so we start blocking at the ioctl + select { + case <-c: + case <-time.After(1 * time.Second): } return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error deviceName := c.GetString("tun.dev", "") if deviceName != "" { - fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0) + fd, err = unix.Open("/dev/"+deviceName, os.O_RDWR, 0) } if errors.Is(err, fs.ErrNotExist) || deviceName == "" { // If the device doesn't already exist, request a new one and rename it - fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0) + fd, err = unix.Open("/dev/tun", os.O_RDWR, 0) } if err != nil { return nil, err } + if err = unix.SetNonblock(fd, true); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to set tun device as nonblocking: %w", err) + } + + // Shutdown pipe lets Close wake any reader/writer blocked in Poll. 
+ var pipeFds [2]int + if err = unix.Pipe2(pipeFds[:], unix.O_CLOEXEC|unix.O_NONBLOCK); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to create shutdown pipe: %w", err) + } + shutdownR, shutdownW := pipeFds[0], pipeFds[1] + + closeOnErr := true + defer func() { + if closeOnErr { + _ = unix.Close(fd) + _ = unix.Close(shutdownR) + _ = unix.Close(shutdownW) + } + }() + // Read the name of the interface var name [16]byte arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} @@ -237,7 +339,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } if ctrlErr != nil { - return nil, err + return nil, ctrlErr } ifName := string(bytes.TrimRight(name[:], "\x00")) @@ -253,8 +355,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } defer syscall.Close(s) - fd := uintptr(s) - var fromName [16]byte var toName [16]byte copy(fromName[:], ifName) @@ -266,7 +366,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } // Set the device name - _ = ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) + _ = ioctl(uintptr(s), syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) } t := &tun{ @@ -274,13 +374,24 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, - devFd: fd, + fd: fd, + shutdownR: shutdownR, + shutdownW: shutdownW, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownR), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownR), Events: unix.POLLIN}, + }, } err = t.reload(c, true) if err != nil { return nil, err } + closeOnErr = false c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) @@ -475,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + 
t.l.Info("Added route", "route", r) } } @@ -490,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..6bfcbdfb 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync" @@ -14,7 +15,6 @@ import ( "syscall" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -25,14 +25,14 @@ type tun struct { vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..c6cfb686 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,8 +4,10 @@ package overlay import ( + "encoding/binary" "fmt" "io" + "log/slog" "net" "net/netip" "os" @@ -16,7 +18,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" 
@@ -24,9 +25,175 @@ import ( "golang.org/x/sys/unix" ) +// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. +// A shared eventfd allows Close to wake all readers blocked in poll. +type tunFile struct { + fd int + shutdownFd int + lastOne bool + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed bool +} + +// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun +func (r *tunFile) newFriend(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + return &tunFile{ + fd: fd, + shutdownFd: r.shutdownFd, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + }, nil +} + +func newTunFd(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &tunFile{ + fd: fd, + shutdownFd: shutdownFd, + lastOne: true, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + } + + return out, nil +} + +func (r *tunFile) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! 
+ tunEvents := r.readPoll[0].Revents + shutdownEvents := r.readPoll[1].Revents + r.readPoll[0].Revents = 0 + r.readPoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.writePoll[0].Revents + shutdownEvents := r.writePoll[1].Revents + r.writePoll[0].Revents = 0 + r.writePoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) Read(buf []byte) (int, error) { + for { + if n, err := unix.Read(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnRead(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) Write(buf []byte) (int, error) { + for { + if n, err := unix.Write(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnWrite(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +func (r *tunFile) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) + return err +} + +func (r *tunFile) 
Close() error { + if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem + return nil + } + r.closed = true + if r.lastOne { + _ = unix.Close(r.shutdownFd) + } + return unix.Close(r.fd) +} + type tun struct { - io.ReadWriteCloser - fd int + *tunFile + readers []*tunFile + closeLock sync.Mutex Device string vpnNetworks []netip.Prefix MaxMTU int @@ -46,7 +213,7 @@ type tun struct { routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystemLock sync.Mutex - l *logrus.Logger + l *slog.Logger } func (t *tun) Networks() []netip.Prefix { @@ -71,10 +238,8 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - - t, err := newTunGeneric(c, l, file, vpnNetworks) +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { + t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) if err != nil { return nil, err } @@ -84,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -115,6 +280,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu nameStr := c.GetString("tun.dev", "") copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, &NameError{ Name: nameStr, Underlying: err, @@ -122,8 +288,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := 
strings.Trim(string(req.Name[:]), "\x00") - file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, fd, vpnNetworks) if err != nil { return nil, err } @@ -133,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. +func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { + tfd, err := newTunFd(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), + tunFile: tfd, + readers: []*tunFile{tfd}, + closeLock: sync.Mutex{}, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), @@ -145,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n l: l, } - err := t.reload(c, true) - if err != nil { + if err = t.reload(c, true); err != nil { + _ = t.Close() return nil, err } @@ -206,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error { if !initial { if oldMaxMTU != newMaxMTU { t.setMTU() - t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) + t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU) } if oldDefaultMTU != newDefaultMTU { for i := range t.vpnNetworks { err := t.setDefaultRoute(t.vpnNetworks[i]) if err != nil { - t.l.Warn(err) + t.l.Warn(err.Error()) } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU) } } } @@ -239,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool { } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + t.closeLock.Lock() + 
defer t.closeLock.Unlock() + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -248,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + out, err := t.tunFile.newFriend(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } - return file, nil + t.readers = append(t.readers, out) + + return out, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -261,29 +443,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t *tun) Write(b []byte) (int, error) { - var nn int - maximum := len(b) - - for { - n, err := unix.Write(t.fd, b[nn:maximum]) - if n > 0 { - nn += n - } - if nn == len(b) { - return nn, err - } - - if err != nil { - return nn, err - } - - if n == 0 { - return nn, io.ErrUnexpectedEOF - } - } -} - func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) @@ -333,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error { } err = netlink.AddrDel(link, &al[i]) if err != nil { - t.l.WithError(err).Error("failed to remove address from tun address list") + t.l.Error("failed to remove address from tun address list", "error", err) } else { - t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + t.l.Info("removed address not listed in cert(s)", "removed", al[i].String()) } } @@ -379,12 +538,12 @@ func (t *tun) Activate() error { ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { // If we can't set the queue length nebula will still work but it may lead to packet loss - t.l.WithError(err).Error("Failed to set 
tun tx queue length") + t.l.Error("Failed to set tun tx queue length", "error", err) } const modeNone = 1 if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil { - t.l.WithError(err).Warn("Failed to disable link local address generation") + t.l.Warn("Failed to disable link local address generation", "error", err) } if err = t.addIPs(link); err != nil { @@ -423,7 +582,7 @@ func (t *tun) setMTU() { ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { // This is currently a non fatal condition because the route table must have the MTU set appropriately as well - t.l.WithError(err).Error("Failed to set tun mtu") + t.l.Error("Failed to set tun mtu", "error", err) } } @@ -446,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { } err := netlink.RouteReplace(&nr) if err != nil { - t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` for i := 0; i < 2; i++ { time.Sleep(100 * time.Millisecond) @@ -454,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { if err == nil { break } else { - t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", + "error", err, + "cidr", cidr, + "mtu", t.DefaultMTU, + ) } } if err != nil { @@ -499,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -531,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) { err := netlink.RouteDel(&nr) if err != nil { - t.l.WithError(err).WithField("route", 
r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } } @@ -562,11 +725,11 @@ func (t *tun) watchRoutes() { netlinkOptions := netlink.RouteSubscribeOptions{ ReceiveBufferSize: t.useSystemRoutesBufferSize, ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0, - ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") }, + ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) }, } if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil { - t.l.WithError(err).Errorf("failed to subscribe to system route changes") + t.l.Error("failed to subscribe to system route changes", "error", err) return } @@ -608,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { link, err := netlink.LinkByName(t.Device) if err != nil { - t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") + t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device) return gateways } @@ -620,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { gateways = append(gateways, routing.NewGateway(gwAddr, 1)) } else { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } @@ -636,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) } else { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring 
route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } } @@ -671,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { gateways := t.getGatewaysFromRoute(&r.Route) if len(gateways) == 0 { // No gateways relevant to our network, no routing changes required. - t.l.WithField("route", r).Debug("Ignoring route update, no gateways") + t.l.Debug("Ignoring route update, no gateways", "route", r) return } if r.Dst == nil { - t.l.WithField("route", r).Debug("Ignoring route update, no destination address") + t.l.Debug("Ignoring route update, no destination address", "route", r) return } dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + t.l.Debug("Ignoring route update, invalid destination address", "route", r) return } @@ -693,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { t.routesFromSystemLock.Lock() if r.Type == unix.RTM_NEWROUTE { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + t.l.Info("Adding route", "destination", dst, "via", gateways) t.routesFromSystem[dst] = gateways newTree.Insert(dst, gateways) } else { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") + t.l.Info("Removing route", "destination", dst, "via", gateways) delete(t.routesFromSystem, dst) newTree.Delete(dst) } @@ -707,18 +870,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } func (t *tun) Close() error { + t.closeLock.Lock() + defer t.closeLock.Unlock() + if t.routeChan != nil { close(t.routeChan) + t.routeChan = nil } - if t.ReadWriteCloser != nil { - _ = t.ReadWriteCloser.Close() - } + // Signal all readers blocked in poll to wake up and exit + 
_ = t.tunFile.wakeForShutdown() if t.ioctlFd > 0 { - _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = unix.Close(int(t.ioctlFd)) t.ioctlFd = 0 } - return nil + for i := range t.readers { + if i == 0 { + continue //we want to close the zeroth reader last + } + err := t.readers[i].Close() + if err != nil { + t.l.Error("error closing tun reader", "reader", i, "error", err) + } else { + t.l.Info("closed tun reader", "reader", i) + } + } + + //this is t.readers[0] too + err := t.tunFile.Close() + if err != nil { + t.l.Error("error closing tun reader", "reader", 0, "error", err) + } else { + t.l.Info("closed tun reader", "reader", 0) + } + return err } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..c971bb6e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -63,18 +63,18 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = unix.SetNonblock(fd, true) if err != nil { - 
l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := &tun{ @@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..81362184 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -54,7 +54,7 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -63,11 +63,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := 
c.GetString("tun.dev", "") @@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = unix.SetNonblock(fd, true) if err != nil { - l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := &tun{ @@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..b2c2a0ea 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -4,14 +4,15 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) @@ -21,14 +22,14 @@ type TestTun struct { vpnNetworks []netip.Prefix Routes []Route routeTree *bart.Table[routing.Gateways] - l *logrus.Logger + l *slog.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err @@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( }, nil } -func 
newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) { return } - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Tun receiving injected packet", "dataLen", len(packet)) } t.rxPackets <- packet } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 223eabee..680dddb3 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -7,6 +7,7 @@ import ( "crypto" "fmt" "io" + "log/slog" "net/netip" "os" "path/filepath" @@ -16,7 +17,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -33,16 +33,16 @@ type winTun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) @@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( if err != nil { // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. 
// Trying a second time resolves the issue. - l.WithError(err).Debug("Failed to create wintun device, retrying") + l.Debug("Failed to create wintun device, retrying", "error", err) tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { return nil, &NameError{ @@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } if !foundDefault4 { @@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error { // See comment on luid.AddRoute err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..e5f27f37 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,14 +2,14 @@ package overlay import ( "io" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } diff --git a/pki.go b/pki.go index 19869d58..fb8cc5c6 100644 --- a/pki.go +++ b/pki.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "log/slog" "net" "net/netip" "os" @@ -14,7 +16,6 @@ import ( "time" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" @@ -23,7 +24,7 @@ import ( type PKI struct { cs atomic.Pointer[CertState] caPool atomic.Pointer[cert.CAPool] - l 
*logrus.Logger + l *slog.Logger } type CertState struct { @@ -45,7 +46,7 @@ type CertState struct { myVpnBroadcastAddrsTable *bart.Lite } -func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { +func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) { pki := &PKI{l: l} err := pki.reload(c, true) if err != nil { @@ -181,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { p.cs.Store(newState) if initial { - p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") + p.l.Debug("Client nebula certificate(s)", "cert", newState) } else { - p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") + p.l.Info("Client certificate(s) refreshed from disk", "cert", newState) } return nil } @@ -195,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { } p.caPool.Store(caPool) - p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints()) return nil } @@ -486,32 +487,32 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) { return c, b, nil } -func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { - var rawCA []byte - var err error - +func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) { caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") } - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) + var caReader io.ReadCloser + var err error + if strings.Contains(caPathOrPEM, "-----BEGIN") { + caReader = io.NopCloser(strings.NewReader(caPathOrPEM)) } else { - rawCA, err = os.ReadFile(caPathOrPEM) + caReader, err = os.Open(caPathOrPEM) if err != nil { return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) } } + defer caReader.Close() - caPool, err := 
cert.NewCAPoolFromPEM(rawCA) + caPool, err := cert.NewCAPoolFromPEMReader(caReader) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { if crt.Certificate.Expired(time.Now()) { expired++ - l.WithField("cert", crt).Warn("expired certificate present in CA pool") + l.Warn("expired certificate present in CA pool", "cert", crt) } } @@ -529,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { caPool.BlocklistFingerprint(fp) } - l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates") + l.Info("Blocklisted certificates", "fingerprintCount", len(bl)) } return caPool, nil diff --git a/pki_hup_benchmark_test.go b/pki_hup_benchmark_test.go new file mode 100644 index 00000000..bca23d78 --- /dev/null +++ b/pki_hup_benchmark_test.go @@ -0,0 +1,121 @@ +package nebula + +import ( + "bytes" + "fmt" + "net/netip" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + cert_test "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/require" +) + +func BenchmarkReloadConfigWithCAs(b *testing.B) { + prevProcs := runtime.GOMAXPROCS(1) + b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) }) + + for _, size := range []int{100, 250, 500, 1000, 5000} { + b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) { + l := test.NewLogger() + dir := b.TempDir() + + ca, caKey, caBundle := buildCABundle(b, size) + caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle) + + configBody := fmt.Sprintf(`pki: + ca: %s + cert: %s + key: %s +`, caPath, certPath, keyPath) + + configPath := filepath.Join(dir, "config.yml") + require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600)) + + c := config.NewC(l) + require.NoError(b, c.Load(dir)) + + _, err := NewPKIFromConfig(test.NewLogger(), c) + require.NoError(b, err) + + b.ReportAllocs() + b.ResetTimer() + + for 
b.Loop() { + c.ReloadConfig() + } + }) + } +} + +func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) { + b.Helper() + require.GreaterOrEqual(b, count, 1) + + before := time.Now().Add(-24 * time.Hour) + after := time.Now().Add(24 * time.Hour) + + ca, _, caKey, pem := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + before, + after, + nil, + nil, + nil, + ) + + buf := bytes.NewBuffer(pem) + buf.Write([]byte("\n# a comment!\n")) + + for i := 1; i < count; i++ { + _, _, _, extraPEM := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + time.Now(), + time.Now().Add(time.Hour), + nil, + nil, + nil, + ) + buf.Write([]byte("\n# a comment!\n")) + buf.Write(extraPEM) + } + + return ca, caKey, buf.Bytes() +} + +func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) { + b.Helper() + + networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")} + + _, _, keyPEM, certPEM := cert_test.NewTestCert( + cert.Version2, + cert.Curve_CURVE25519, + ca, + caKey, + "reload-benchmark", + time.Now(), + time.Now().Add(time.Hour), + networks, + nil, + nil, + ) + + caPath := filepath.Join(dir, "ca.pem") + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + require.NoError(b, os.WriteFile(caPath, caBundle, 0o600)) + require.NoError(b, os.WriteFile(certPath, certPEM, 0o600)) + require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600)) + + return caPath, certPath, keyPath +} diff --git a/punchy.go b/punchy.go index 2034405a..6ecf4f85 100644 --- a/punchy.go +++ b/punchy.go @@ -1,10 +1,10 @@ package nebula import ( + "log/slog" "sync/atomic" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) @@ -14,10 +14,10 @@ type Punchy struct { delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool - l *logrus.Logger + l *slog.Logger } -func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { 
+func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { p := &Punchy{l: l} p.reload(c, true) @@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { p.respond.Store(yes) if !initial { - p.l.Infof("punchy.respond changed to %v", p.GetRespond()) + p.l.Info("punchy.respond changed", "respond", p.GetRespond()) } } @@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) { if initial || c.HasChanged("punchy.delay") { p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) if !initial { - p.l.Infof("punchy.delay changed to %s", p.GetDelay()) + p.l.Info("punchy.delay changed", "delay", p.GetDelay()) } } if initial || c.HasChanged("punchy.target_all_remotes") { p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) if !initial { - p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") + p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) } } if initial || c.HasChanged("punchy.respond_delay") { p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) if !initial { - p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay()) + p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) } } } diff --git a/punchy_test.go b/punchy_test.go index 56dd1c25..cbf9b17b 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -1,6 +1,8 @@ package nebula import ( + "context" + "log/slog" "testing" "time" @@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.False(t, p.GetPunch()) assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) @@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) { // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) 
assert.True(t, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } @@ -62,7 +64,7 @@ punchy: delay: 1m respond: false `)) - p := NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, delay, p.GetDelay()) assert.False(t, p.GetRespond()) @@ -76,3 +78,158 @@ punchy: assert.Equal(t, newDelay, p.GetDelay()) assert.True(t, p.GetRespond()) } + +// The tests below pin the shape of each log line Punchy produces so changes +// cannot silently break whatever operators are grepping for. The assertions +// are on the structured message + attrs (e.g. "punchy.respond changed" with +// a respond=true field) rather than a formatted string. +// +// Punchy.reload also emits a spurious "Changing punchy.punch with reload is +// not supported" warning whenever any key under punchy changes, because of +// the c.HasChanged("punchy") fallback kept for the deprecated top-level +// punchy form. The tests filter by message rather than asserting total +// entry counts so that warning is tolerated without being locked into +// the format. 
+ +type capturedEntry struct { + Level slog.Level + Msg string + Attrs map[string]any +} + +// capturingHandler is a slog.Handler that records each Record it receives so +// tests can assert on the level, message, and attribute map of individual log +// lines without coupling to any specific text format. +type capturingHandler struct { + entries []capturedEntry +} + +func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } + +func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error { + e := capturedEntry{ + Level: r.Level, + Msg: r.Message, + Attrs: make(map[string]any), + } + r.Attrs(func(a slog.Attr) bool { + e.Attrs[a.Key] = a.Value.Resolve().Any() + return true + }) + h.entries = append(h.entries, e) + return nil +} + +func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h } + +func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) { + t.Helper() + hook := &capturingHandler{} + return slog.New(hook), hook +} + +func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry { + t.Helper() + for _, e := range entries { + if e.Msg == msg { + return e + } + } + t.Fatalf("no entry with message %q among %d entries", msg, len(entries)) + return capturedEntry{} +} + +func TestPunchy_LogFormat_InitialEnabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: true}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, "punchy enabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_InitialDisabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: false}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, 
"punchy disabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) + + entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") + assert.Equal(t, slog.LevelWarn, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) + + entry := findEntry(t, hook.entries, "punchy.respond changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) + + entry := findEntry(t, hook.entries, "punchy.delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"delay": 10 * time.Second}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) + + entry := findEntry(t, 
hook.entries, "punchy.target_all_remotes changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"target_all_remotes": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`)) + + entry := findEntry(t, hook.entries, "punchy.respond_delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, entry.Attrs) +} diff --git a/relay_manager.go b/relay_manager.go index 5dd355ca..919bb2b6 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -5,22 +5,22 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net/netip" "sync/atomic" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) type relayManager struct { - l *logrus.Logger + l *slog.Logger hostmap *HostMap amRelay atomic.Bool } -func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager { +func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager { rm := &relayManager{ l: l, hostmap: hostmap, @@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c c.RegisterReloadCallback(func(c *config.C) { err := rm.reload(c, false) if err != nil { - l.WithError(err).Error("Failed to reload relay_manager") + rm.l.Error("Failed to reload relay_manager", "error", err) } }) return rm @@ -52,10 +52,10 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. 
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() - for i := 0; i < 32; i++ { + for range 32 { index, err := generateIndex(l) if err != nil { return 0, err @@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - fields := logrus.Fields{ - "relay": relayHostInfo.vpnAddrs[0], - "initiatorRelayIndex": m.InitiatorRelayIndex, - } - + var relayFrom, relayTo any if m.RelayFromAddr == nil { - fields["relayFrom"] = m.OldRelayFromAddr + relayFrom = m.OldRelayFromAddr } else { - fields["relayFrom"] = m.RelayFromAddr + relayFrom = m.RelayFromAddr } - if m.RelayToAddr == nil { - fields["relayTo"] = m.OldRelayToAddr + relayTo = m.OldRelayToAddr } else { - fields["relayTo"] = m.RelayToAddr + relayTo = m.RelayToAddr } - rm.l.WithFields(fields).Info("relayManager failed to update relay") + rm.l.Info("relayManager failed to update relay", + "relay", relayHostInfo.vpnAddrs[0], + "initiatorRelayIndex", m.InitiatorRelayIndex, + "relayFrom", relayFrom, + "relayTo", relayTo, + ) return nil, fmt.Errorf("unknown relay") } @@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { msg := &NebulaControl{} err := msg.Unmarshal(d) if err != nil { - h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + h.logger(f.l).Error("Failed to unmarshal control message", "error", err) return } @@ -147,20 +147,20 @@ func (rm 
*relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { } func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { - rm.l.WithFields(logrus.Fields{ - "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), - "relayTo": protoAddrToNetAddr(m.RelayToAddr), - "initiatorRelayIndex": m.InitiatorRelayIndex, - "responderRelayIndex": m.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). - Info("handleCreateRelayResponse") + rm.l.Info("handleCreateRelayResponse", + "relayFrom", protoAddrToNetAddr(m.RelayFromAddr), + "relayTo", protoAddrToNetAddr(m.RelayToAddr), + "initiatorRelayIndex", m.InitiatorRelayIndex, + "responderRelayIndex", m.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) target := m.RelayToAddr targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { - rm.l.WithError(err).Error("Failed to update relay for relayTo") + rm.l.Error("Failed to update relay for relayTo", "error", err) return } // Do I need to complete the relays now? @@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f // I'm the middle man. Let the initiator know that the I've established the relay they requested. 
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") + rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr) return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") + rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0]) return } switch peerRelay.State { @@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f if v == cert.Version1 { peer := peerHostInfo.vpnAddrs[0] if !peer.Is4() { - rm.l.WithField("relayFrom", peer). - WithField("relayTo", target). - WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). - WithField("responderRelayIndex", resp.ResponderRelayIndex). - WithField("vpnAddrs", peerHostInfo.vpnAddrs). - Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address", + "relayFrom", peer, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) return } @@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - rm.l.WithError(err). 
- Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromAddr, - "relayTo": resp.RelayToAddr, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": peerHostInfo.vpnAddrs}). - Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", resp.RelayFromAddr, + "relayTo", resp.RelayToAddr, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) } } } @@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f from := protoAddrToNetAddr(m.RelayFromAddr) target := protoAddrToNetAddr(m.RelayToAddr) - logMsg := rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnAddrs": h.vpnAddrs}) + logMsg := rm.l.With( + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", m.InitiatorRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. if f.myVpnAddrsTable.Contains(from) { - logMsg.WithField("myIP", from).Error("Discarding relay request from myself") + logMsg.Error("Discarding relay request from myself", "myIP", from) return } @@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. 
// This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } case Disestablished: if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } // Mark the relay as 'Established' because it's safe to use again h.relayState.UpdateRelayForByIpState(from, Established) case PeerRequested: // I should never be in this state, because I am terminal, not forwarding. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex, - "state": existingRelay.State}).Error("Unexpected Relay State found") + logMsg.Error("Unexpected Relay State found", + "existingRemoteIndex", existingRelay.RemoteIndex, + "state", existingRelay.State) } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) if err != nil { - logMsg.WithError(err).Error("Failed to add relay") + logMsg.Error("Failed to add relay", "error", err) return } } relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.WithField("from", from).Error("Relay State not found") + logMsg.Error("Relay State not found", "from", from) return } @@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - logMsg. 
- WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). - Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) } return } else { @@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if v == cert.Version1 { if !h.vpnAddrs[0].Is4() { - rm.l.WithField("relayFrom", h.vpnAddrs[0]). - WithField("relayTo", target). - WithField("initiatorRelayIndex", req.InitiatorRelayIndex). - WithField("responderRelayIndex", req.ResponderRelayIndex). - WithField("vpnAddr", target). - Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) return } @@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := req.Marshal() if err != nil { - logMsg. 
- WithError(err).Error("relayManager Failed to marshal Control message to create relay") + logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": h.vpnAddrs[0], - "relayTo": target, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddr": target}). - Info("send CreateRelayRequest") + rm.l.Info("send CreateRelayRequest", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) } // Also track the half-created Relay state just received @@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if !ok { _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { - logMsg. 
- WithError(err).Error("relayManager Failed to allocate a local index for relay") + logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err) return } } diff --git a/remote_list.go b/remote_list.go index 1304fd51..7b95de87 100644 --- a/remote_list.go +++ b/remote_list.go @@ -2,6 +2,7 @@ package nebula import ( "context" + "log/slog" "net" "net/netip" "slices" @@ -10,8 +11,6 @@ import ( "sync" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // forEachFunc is used to benefit folks that want to do work inside the lock @@ -66,11 +65,11 @@ type hostnamesResults struct { network string lookupTimeout time.Duration cancelFn func() - l *logrus.Logger + l *slog.Logger ips atomic.Pointer[map[netip.AddrPort]struct{}] } -func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { +func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { r := &hostnamesResults{ hostnames: make([]hostnamePort, len(hostPorts)), network: network, @@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) timeoutCancel() if err != nil { - l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") + l.Error("DNS resolution failed for static_map host", + "hostname", hostPort.name, + "network", r.network, + "error", err, + ) continue } for _, a := range addrs { @@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, } } if different { - l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + l.Info("DNS results 
changed for host list", + "origSet", origSet, + "newSet", netipAddrs, + ) r.ips.Store(&netipAddrs) onUpdate() } @@ -404,12 +410,7 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { - for _, v := range r.badRemotes { - if v == remote { - return true - } - } - return false + return slices.Contains(r.badRemotes, remote) } // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the diff --git a/service/service.go b/service/service.go index fc8ac97a..899e851d 100644 --- a/service/service.go +++ b/service/service.go @@ -44,7 +44,10 @@ type Service struct { } func New(control *nebula.Control) (*Service, error) { - control.Start() + wait, err := control.Start() + if err != nil { + return nil, err + } ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) @@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) { } }) + // Add the nebula wait function to the group so a fatal reader error + // propagates out through errgroup.Wait(). 
+ eg.Go(func() error { + return wait() + }) + return &s, nil } diff --git a/service/service_test.go b/service/service_test.go index c6b87423..4bcc8437 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,11 +10,11 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "go.yaml.in/yaml/v3" "golang.org/x/sync/errgroup" @@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n panic(err) } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/ssh.go b/ssh.go index 9a26c290..3863b5ec 100644 --- a/ssh.go +++ b/ssh.go @@ -6,19 +6,21 @@ import ( "errors" "flag" "fmt" + "log/slog" + "maps" "net" "net/netip" "os" - "reflect" + "path/filepath" "runtime" "runtime/pprof" "sort" "strconv" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/sshd" ) @@ -55,12 +57,12 @@ type sshDeviceInfoFlags struct { Pretty bool } -func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { +func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) { c.RegisterReloadCallback(func(c *config.C) { if c.GetBool("sshd.enabled", false) { sshRun, err := configSSH(l, ssh, c) if err != nil { - l.WithError(err).Error("Failed to reconfigure the sshd") + l.Error("Failed to reconfigure the sshd", "error", err) ssh.Stop() } if sshRun != nil { @@ -76,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { // updates the passed-in SSHServer. 
On success, it returns a function // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. -func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { +func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { listen := c.GetString("sshd.listen", "") if listen == "" { return nil, fmt.Errorf("sshd.listen must be provided") @@ -118,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, caAuthorizedKey := range rawCAs { err := ssh.AddTrustedCA(caAuthorizedKey) if err != nil { - l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring") + l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey) continue } } @@ -129,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, rk := range keys { kDef, ok := rk.(map[string]any) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") + l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk) continue } user, ok := kDef["user"].(string) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field") + l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk) continue } @@ -144,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro case string: err := ssh.AddAuthorizedKey(user, v) if err != nil { - l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", rk, + "sshKey", v, + ) continue } @@ -152,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, subK := range v { sk, ok := subK.(string) if !ok { - l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key") + l.Warn("Did not 
understand ssh key", + "sshKeyConfig", rk, + "sshKey", subK, + ) continue } err := ssh.AddAuthorizedKey(user, sk) if err != nil { - l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", sk, + ) continue } } default: - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood") + l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk) } } } else { @@ -176,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro ssh.Stop() runner = func() { if err := ssh.Run(listen); err != nil { - l.WithField("err", err).Warn("Failed to run the SSH server") + l.Warn("Failed to run the SSH server", "error", err) } } } else { @@ -186,7 +198,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return runner, nil } -func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { +func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { + // sandboxDir defaults to a dir in temp. The intention is that end user will + // create this dir as needed. Overriding this config value to "" allows + // writing to anywhere in the system. 
+ defaultDir := filepath.Join(os.TempDir(), "nebula-debug") + sandboxDir := c.GetString("sshd.sandbox_dir", defaultDir) + ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", @@ -245,7 +263,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "start-cpu-profile", ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`", - Callback: sshStartCpuProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshStartCpuProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -260,7 +280,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-heap-profile", ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`", - Callback: sshGetHeapProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshGetHeapProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -272,7 +294,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-mutex-profile", ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`", - Callback: sshGetMutexProfile, + Callback: func(fs any, a []string, w sshd.StringWriter) error { + return sshGetMutexProfile(sandboxDir, fs, a, w) + }, }) ssh.RegisterCommand(&sshd.Command{ @@ -505,13 +529,43 @@ func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) er return nil } -func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error { +// sshSanitizeFilePath validates that the given file path is within the sandbox directory. +// If sandboxDir is empty, the path is returned as-is for backwards compatibility. 
+func sshSanitizeFilePath(sandboxDir, filePath string) (string, error) { + if sandboxDir == "" { + return filePath, nil + } + + // Clean and resolve the path relative to the sandbox directory + if !filepath.IsAbs(filePath) { + filePath = filepath.Join(sandboxDir, filePath) + } + cleaned := filepath.Clean(filePath) + + // Ensure the resolved path is within the sandbox directory + cleanedSandbox := filepath.Clean(sandboxDir) + if cleaned == cleanedSandbox { + return "", fmt.Errorf("path %q resolves to the sandbox directory itself %q", filePath, sandboxDir) + } + if !strings.HasPrefix(cleaned, cleanedSandbox+string(filepath.Separator)) { + return "", fmt.Errorf("path %q is outside the sandbox directory %q", filePath, sandboxDir) + } + + return cleaned, nil +} + +func sshStartCpuProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { err := w.WriteLine("No path to write profile provided") return err } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) return err @@ -675,12 +729,17 @@ func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) e return w.WriteLine("Changed") } -func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error { +func sshGetHeapProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) return err @@ -711,12 +770,17 @@ func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error { return 
w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate)) } -func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { +func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } - file, err := os.Create(a[0]) + filePath, err := sshSanitizeFilePath(sandboxDir, a[0]) + if err != nil { + return w.WriteLine(err.Error()) + } + + file, err := os.Create(filePath) if err != nil { return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) } @@ -735,36 +799,45 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetLevel() slog.Level + SetLevel(slog.Level) + }) + if !ok { + return w.WriteLine("Log level is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } - level, err := logrus.ParseLevel(a[0]) + level, err := logging.ParseLevel(strings.ToLower(a[0])) if err != nil { - return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels)) + return w.WriteLine(fmt.Sprintf("Unknown log level %s. 
Possible log levels: trace, debug, info, warn, error", a)) } - l.SetLevel(level) - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + ctrl.SetLevel(level) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } -func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetFormat() string + SetFormat(string) error + }) + if !ok { + return w.WriteLine("Log format is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } - logFormat := strings.ToLower(a[0]) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{} - case "json": - l.Formatter = &logrus.JSONFormatter{} - default: - return fmt.Errorf("unknown log format `%s`. 
possible formats: %s", logFormat, []string{"text", "json"}) + if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil { + return err } - - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { @@ -831,9 +904,7 @@ func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) er relays := map[uint32]*HostInfo{} ifce.hostMap.Lock() - for k, v := range ifce.hostMap.Relays { - relays[k] = v - } + maps.Copy(relays, ifce.hostMap.Relays) ifce.hostMap.Unlock() type RelayFor struct { diff --git a/sshd/server.go b/sshd/server.go index a8b60ba7..38886e53 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -2,19 +2,19 @@ package sshd import ( "bytes" + "context" "errors" "fmt" + "log/slog" "net" - "sync" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) type SSHServer struct { config *ssh.ServerConfig - l *logrus.Entry + l *slog.Logger certChecker *ssh.CertChecker @@ -27,20 +27,21 @@ type SSHServer struct { commands *radix.Tree listener net.Listener - // Locks the conns/counter to avoid concurrent map access - connsLock sync.Mutex - conns map[int]*session - counter int + // Call the cancel() function to stop all active sessions + ctx context.Context + cancel func() } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen -func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { +func NewSSHServer(l *slog.Logger) (*SSHServer, error) { + ctx, cancel := context.WithCancel(context.Background()) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), - conns: make(map[int]*session), + ctx: ctx, + cancel: cancel, } cc := ssh.CertChecker{ @@ -120,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error { } s.trustedCAs = append(s.trustedCAs, pk) - 
s.l.WithField("sshKey", pubKey).Info("Trusted CA key") + s.l.Info("Trusted CA key", "sshKey", pubKey) return nil } @@ -138,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { } tk[string(pk.Marshal())] = true - s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") + s.l.Info("Authorized ssh key", + "sshKey", pubKey, + "sshUser", user, + ) return nil } @@ -155,7 +159,7 @@ func (s *SSHServer) Run(addr string) error { return err } - s.l.WithField("sshListener", addr).Info("SSH server is listening") + s.l.Info("SSH server is listening", "sshListener", addr) // Run loops until there is an error s.run() @@ -171,48 +175,54 @@ func (s *SSHServer) run() { c, err := s.listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { - s.l.WithError(err).Warn("Error in listener, shutting down") + s.l.Warn("Error in listener, shutting down", "error", err) } return } - - conn, chans, reqs, err := ssh.NewServerConn(c, s.config) - fp := "" - if conn != nil { - fp = conn.Permissions.Extensions["fp"] - } - - if err != nil { - l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + go func(c net.Conn) { + // NewServerConn may block while waiting for the client to complete the handshake. + // Ensure that a bad client doesn't hurt us by checking for the parent context + // cancellation before calling NewServerConn, and forcing the socket to close when + // the context is cancelled. 
+ sessionContext, sessionCancel := context.WithCancel(s.ctx) + go func() { + <-sessionContext.Done() + c.Close() + }() + conn, chans, reqs, err := ssh.NewServerConn(c, s.config) + fp := "" if conn != nil { - l = l.WithField("sshUser", conn.User()) - conn.Close() + fp = conn.Permissions.Extensions["fp"] } - if fp != "" { - l = l.WithField("sshFingerprint", fp) + + if err != nil { + l := s.l.With( + "error", err, + "remoteAddress", c.RemoteAddr(), + ) + if conn != nil { + l = l.With("sshUser", conn.User()) + conn.Close() + } + if fp != "" { + l = l.With("sshFingerprint", fp) + } + l.Warn("failed to handshake") + sessionCancel() + return } - l.Warn("failed to handshake") - continue - } - l := s.l.WithField("sshUser", conn.User()) - l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") + l := s.l.With("sshUser", conn.User()) + l.Info("ssh user logged in", + "remoteAddress", c.RemoteAddr(), + "sshFingerprint", fp, + ) - session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) - s.connsLock.Lock() - s.counter++ - counter := s.counter - s.conns[counter] = session - s.connsLock.Unlock() + NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session")) - go ssh.DiscardRequests(reqs) - go func() { - <-session.exitChan - s.l.WithField("id", counter).Debug("closing conn") - s.connsLock.Lock() - delete(s.conns, counter) - s.connsLock.Unlock() - }() + go ssh.DiscardRequests(reqs) + + }(c) } } @@ -220,15 +230,11 @@ func (s *SSHServer) Stop() { // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { - s.l.WithError(err).Warn("Failed to close the sshd listener") + s.l.Warn("Failed to close the sshd listener", "error", err) } } } func (s *SSHServer) closeSessions() { - s.connsLock.Lock() - for _, c := range s.conns { - c.Close() - } - s.connsLock.Unlock() + s.cancel() } diff 
--git a/sshd/session.go b/sshd/session.go index 87cc216f..1c8e1a9b 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -2,30 +2,30 @@ package sshd import ( "fmt" + "log/slog" "sort" "strings" "github.com/anmitsu/go-shlex" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "golang.org/x/term" ) type session struct { - l *logrus.Entry + l *slog.Logger c *ssh.ServerConn term *term.Terminal commands *radix.Tree - exitChan chan bool + cancel func() } -func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { +func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session { s := &session{ commands: radix.NewFromMap(commands.ToMap()), l: l, c: conn, - exitChan: make(chan bool), + cancel: cancel, } s.commands.Insert("logout", &Command{ @@ -42,16 +42,17 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New } func (s *session) handleChannels(chans <-chan ssh.NewChannel) { + defer s.Close() for newChannel := range chans { if newChannel.ChannelType() != "session" { - s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") + s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType()) newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { - s.l.WithError(err).Warn("could not accept channel") + s.l.Warn("could not accept channel", "error", err) continue } @@ -94,13 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { return default: - s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request") + s.l.Debug("Rejected unknown request", "sshRequest", req.Type) err = req.Reply(false, nil) } if err != nil { - s.l.WithError(err).Info("Error handling ssh session requests") - s.Close() + s.l.Info("Error handling ssh 
session requests", "error", err) return } } @@ -123,12 +123,11 @@ func (s *session) createTerm(channel ssh.Channel) *term.Terminal { return "", 0, false } - go s.handleInput(channel) + go s.handleInput() return term } -func (s *session) handleInput(channel ssh.Channel) { - defer s.Close() +func (s *session) handleInput() { w := &stringWriter{w: s.term} for { line, err := s.term.ReadLine() @@ -170,10 +169,9 @@ func (s *session) dispatchCommand(line string, w StringWriter) { } _ = execCommand(c, args[1:], w) - return } func (s *session) Close() { s.c.Close() - s.exitChan <- true + s.cancel() } diff --git a/stats.go b/stats.go index c88c45cc..97ce7cf5 100644 --- a/stats.go +++ b/stats.go @@ -1,13 +1,16 @@ package nebula import ( + "context" "errors" "fmt" - "log" + "log/slog" "net" "net/http" "runtime" "strconv" + "sync" + "sync/atomic" "time" graphite "github.com/cyberdelia/go-metrics-graphite" @@ -15,113 +18,350 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) -// startStats initializes stats from config. On success, if any further work -// is needed to serve stats, it returns a func to handle that work. If no -// work is needed, it'll return nil. On failure, it returns nil, error. -func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { - mType := c.GetString("stats.type", "") - if mType == "" || mType == "none" { - return nil, nil - } +// statsServer owns nebula's stats subsystem: the periodic metric capture +// goroutine and (for prometheus) an HTTP listener. It mirrors the lifecycle +// shape of dnsServer: constructor wires the reload callback, reload records +// config, Start builds and runs the runtime, Stop tears it down. 
+type statsServer struct { + l *slog.Logger + ctx context.Context + buildVersion string + configTest bool - interval := c.GetDuration("stats.interval", 0) - if interval == 0 { - return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) - } + // enabled mirrors "stats configured to a real backend". Start consults + // it so callers don't need to know the gating rules. + enabled atomic.Bool - var startFn func() - switch mType { - case "graphite": - err := startGraphiteStats(l, interval, c, configTest) - if err != nil { - return nil, err - } - case "prometheus": - var err error - startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest) - if err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("stats.type was not understood: %s", mType) - } - - metrics.RegisterDebugGCStats(metrics.DefaultRegistry) - metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) - - go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval) - go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval) - - return startFn, nil + runMu sync.Mutex + runCfg *statsConfig + run *statsRuntime // non-nil while a runtime is live } -func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error { - proto := c.GetString("stats.protocol", "tcp") - host := c.GetString("stats.host", "") - if host == "" { - return errors.New("stats.host can not be empty") +// statsRuntime is the live state owned by a single Start invocation. Start +// stashes a pointer under runMu; Stop and Start's own exit path use pointer +// equality to tell "my runtime" apart from one that replaced it after a +// reload. +type statsRuntime struct { + cancel context.CancelFunc + listener *http.Server // nil for graphite +} + +// statsConfig is the snapshot of stats-related config that drives the runtime. +// It is comparable with == so reload can detect "no change" cheaply. 
+type statsConfig struct { + typ string + interval time.Duration + graphite graphiteConfig + prom promConfig +} + +type graphiteConfig struct { + protocol string + host string + // resolvedAddr is the string form of host resolved at config-load time. + // Including it in the struct means a SIGHUP picks up DNS changes even + // when stats.host hasn't been edited. + resolvedAddr string + prefix string +} + +type promConfig struct { + listen string + path string + namespace string + subsystem string +} + +// newStatsServerFromConfig builds a statsServer, applies the initial config, +// and registers a reload callback. The reload callback is registered before +// the initial config is applied so a SIGHUP can later enable, fix, or disable +// stats even if the initial application failed. +// +// Start is safe to call unconditionally: it no-ops when stats are disabled. +// The returned pointer is always non-nil, even on error. +func newStatsServerFromConfig(ctx context.Context, l *slog.Logger, c *config.C, buildVersion string, configTest bool) (*statsServer, error) { + s := &statsServer{ + l: l, + ctx: ctx, + buildVersion: buildVersion, + configTest: configTest, } - prefix := c.GetString("stats.prefix", "nebula") - addr, err := net.ResolveTCPAddr(proto, host) + c.RegisterReloadCallback(func(c *config.C) { + if err := s.reload(c, false); err != nil { + s.l.Error("Failed to reload stats from config", "error", err) + } + }) + + if err := s.reload(c, true); err != nil { + return s, err + } + return s, nil +} + +// reload records the latest config. On the initial call it only records it; +// Control.Start is what launches the first runtime via statsStart. 
On later +// calls it reconciles the running runtime with the new config: +// +// - newly enabled -> spawn Start +// - newly disabled -> Stop the runtime +// - config changed (still enabled) -> Stop the old, Start the new +// - no change -> no-op +func (s *statsServer) reload(c *config.C, initial bool) error { + newCfg, err := loadStatsConfig(c) if err != nil { - return fmt.Errorf("error while setting up graphite sink: %s", err) + return err + } + enabled := newCfg.typ != "" && newCfg.typ != "none" + + s.runMu.Lock() + sameCfg := s.runCfg != nil && *s.runCfg == newCfg + s.runCfg = &newCfg + running := s.run != nil + s.runMu.Unlock() + + s.enabled.Store(enabled) + + if initial || sameCfg { + return nil } - if !configTest { - l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) - go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) + if running { + s.Stop() + } + if enabled && !s.configTest { + go s.Start() } return nil } -func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { - namespace := c.GetString("stats.namespace", "") - subsystem := c.GetString("stats.subsystem", "") - - listen := c.GetString("stats.listen", "") - if listen == "" { - return nil, fmt.Errorf("stats.listen should not be empty") +// Start builds the runtime from the latest config, spawns the capture loop, +// and blocks until Stop is called or ctx fires. For prometheus it also serves +// the HTTP listener. For graphite it blocks on the capture loop's context. +// Safe to call when stats are disabled or already running (both no-op). 
+func (s *statsServer) Start() { + if !s.enabled.Load() || s.configTest { + return } - path := c.GetString("stats.path", "") - if path == "" { - return nil, fmt.Errorf("stats.path should not be empty") + s.runMu.Lock() + if s.ctx.Err() != nil || s.run != nil || s.runCfg == nil { + s.runMu.Unlock() + return + } + cfg := *s.runCfg + captureFns, listener := s.buildRuntime(cfg) + runCtx, cancel := context.WithCancel(s.ctx) + rt := &statsRuntime{cancel: cancel, listener: listener} + s.run = rt + s.runMu.Unlock() + + go captureStatsLoop(runCtx, cfg.interval, captureFns) + + cleanExit := true + if listener == nil { + // Graphite: no HTTP listener to serve; block until teardown. + <-runCtx.Done() + } else { + cleanExit = s.serveListener(listener) } - pr := prometheus.NewRegistry() - pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) - if !configTest { - go pClient.UpdatePrometheusMetrics() - } - - // Export our version information as labels on a static gauge - g := prometheus.NewGauge(prometheus.GaugeOpts{ - Namespace: namespace, - Subsystem: subsystem, - Name: "info", - Help: "Version information for the Nebula binary", - ConstLabels: prometheus.Labels{ - "version": buildVersion, - "goversion": runtime.Version(), - "boringcrypto": strconv.FormatBool(boringEnabled()), - }, - }) - pr.MustRegister(g) - g.Set(1) - - var startFn func() - if !configTest { - startFn = func() { - l.Infof("Prometheus stats listening on %s at %s", listen, path) - http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) - log.Fatal(http.ListenAndServe(listen, nil)) + // Clear our runtime only if nothing has replaced it. Stop races through + // here too but leaves s.run == nil, so the pointer check skips. + s.runMu.Lock() + if s.run == rt { + rt.cancel() + s.run = nil + // A listener that exited with an error (e.g., bind conflict) leaves + // runCfg cached as if it were applied. 
Drop it so a SIGHUP with the + // same config re-triggers Start once the user fixes the underlying + // problem. + if !cleanExit { + s.runCfg = nil } } - - return startFn, nil + s.runMu.Unlock() +} + +// serveListener runs ListenAndServe and ensures ctx cancellation unblocks it. +// Returns true if the listener exited cleanly (Stop, ctx cancellation, or any +// other http.ErrServerClosed path), false on an unexpected error. +func (s *statsServer) serveListener(listener *http.Server) bool { + // Per-invocation watcher: ctx cancellation triggers a listener shutdown + // which in turn unblocks ListenAndServe. Closing `done` on exit keeps + // the watcher from outliving this call. + done := make(chan struct{}) + go func() { + select { + case <-s.ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := listener.Shutdown(shutdownCtx); err != nil { + s.l.Warn("Failed to shut down prometheus stats listener", "error", err) + } + case <-done: + } + }() + defer close(done) + + s.l.Info("Starting prometheus stats listener", "addr", listener.Addr) + err := listener.ListenAndServe() + if err == nil || errors.Is(err, http.ErrServerClosed) { + return true + } + s.l.Error("Prometheus stats listener exited", "error", err) + return false +} + +// Stop tears down the active runtime, if any. Idempotent. +func (s *statsServer) Stop() { + s.runMu.Lock() + rt := s.run + s.run = nil + s.runMu.Unlock() + if rt == nil { + return + } + rt.cancel() + if rt.listener != nil { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := rt.listener.Shutdown(shutdownCtx); err != nil { + s.l.Warn("Failed to shut down prometheus stats listener", "error", err) + } + cancel() + } +} + +// buildRuntime produces the capture functions and, for prometheus, an un-served +// http.Server from cfg. cfg has already been validated by loadStatsConfig. 
+func (s *statsServer) buildRuntime(cfg statsConfig) ([]func(), *http.Server) { + // rcrowley/go-metrics guards these registrations with a private sync.Once, + // so subsequent reloads are no-ops. + metrics.RegisterDebugGCStats(metrics.DefaultRegistry) + metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) + + captureFns := []func(){ + func() { metrics.CaptureDebugGCStatsOnce(metrics.DefaultRegistry) }, + func() { metrics.CaptureRuntimeMemStatsOnce(metrics.DefaultRegistry) }, + } + + switch cfg.typ { + case "graphite": + // loadStatsConfig already resolved and validated the address; re-parse + // the resolved form (no DNS lookup) to get a *net.TCPAddr. + addr, _ := net.ResolveTCPAddr(cfg.graphite.protocol, cfg.graphite.resolvedAddr) + gcfg := graphite.Config{ + Addr: addr, + Registry: metrics.DefaultRegistry, + FlushInterval: cfg.interval, + DurationUnit: time.Nanosecond, + Prefix: cfg.graphite.prefix, + Percentiles: []float64{0.5, 0.75, 0.95, 0.99, 0.999}, + } + captureFns = append(captureFns, func() { + if err := graphite.Once(gcfg); err != nil { + s.l.Error("Graphite export failed", "error", err) + } + }) + s.l.Info("Starting graphite stats", + "interval", cfg.interval, + "prefix", cfg.graphite.prefix, + "addr", addr, + ) + return captureFns, nil + + case "prometheus": + pr := prometheus.NewRegistry() + pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, cfg.prom.namespace, cfg.prom.subsystem, pr, cfg.interval) + captureFns = append(captureFns, func() { + if err := pClient.UpdatePrometheusMetricsOnce(); err != nil { + s.l.Error("Prometheus metrics update failed", "error", err) + } + }) + + g := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: cfg.prom.namespace, + Subsystem: cfg.prom.subsystem, + Name: "info", + Help: "Version information for the Nebula binary", + ConstLabels: prometheus.Labels{ + "version": s.buildVersion, + "goversion": runtime.Version(), + "boringcrypto": strconv.FormatBool(boringEnabled()), + }, + }) + pr.MustRegister(g) + 
g.Set(1) + + // promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger, + // so bridge our slog.Logger back to a *log.Logger that emits at Error. + errLog := slog.NewLogLogger(s.l.Handler(), slog.LevelError) + mux := http.NewServeMux() + mux.Handle(cfg.prom.path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog})) + return captureFns, &http.Server{Addr: cfg.prom.listen, Handler: mux} + } + return captureFns, nil +} + +// captureStatsLoop runs each fn on every tick of d until ctx is cancelled. +func captureStatsLoop(ctx context.Context, d time.Duration, fns []func()) { + t := time.NewTicker(d) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + for _, fn := range fns { + fn() + } + } + } +} + +func loadStatsConfig(c *config.C) (statsConfig, error) { + cfg := statsConfig{ + typ: c.GetString("stats.type", ""), + } + if cfg.typ == "" || cfg.typ == "none" { + return cfg, nil + } + + cfg.interval = c.GetDuration("stats.interval", 0) + if cfg.interval == 0 { + return cfg, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) + } + + switch cfg.typ { + case "graphite": + cfg.graphite.protocol = c.GetString("stats.protocol", "tcp") + cfg.graphite.host = c.GetString("stats.host", "") + if cfg.graphite.host == "" { + return cfg, errors.New("stats.host can not be empty") + } + addr, err := net.ResolveTCPAddr(cfg.graphite.protocol, cfg.graphite.host) + if err != nil { + return cfg, fmt.Errorf("error while setting up graphite sink: %s", err) + } + cfg.graphite.resolvedAddr = addr.String() + cfg.graphite.prefix = c.GetString("stats.prefix", "nebula") + case "prometheus": + cfg.prom.listen = c.GetString("stats.listen", "") + if cfg.prom.listen == "" { + return cfg, errors.New("stats.listen should not be empty") + } + cfg.prom.path = c.GetString("stats.path", "") + if cfg.prom.path == "" { + return cfg, errors.New("stats.path should not be empty") + } + cfg.prom.namespace = 
c.GetString("stats.namespace", "") + cfg.prom.subsystem = c.GetString("stats.subsystem", "") + default: + return cfg, fmt.Errorf("stats.type was not understood: %s", cfg.typ) + } + + return cfg, nil } diff --git a/stats_test.go b/stats_test.go new file mode 100644 index 00000000..20b17c0e --- /dev/null +++ b/stats_test.go @@ -0,0 +1,410 @@ +package nebula + +import ( + "context" + "io" + "log/slog" + "net" + "strconv" + "testing" + "time" + + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestStatsServer(t *testing.T) (*statsServer, *config.C) { + t.Helper() + l := slog.New(slog.DiscardHandler) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return &statsServer{ + l: l, + ctx: ctx, + }, config.NewC(l) +} + +func setStatsConfig(c *config.C, m map[string]any) { + c.Settings["stats"] = m +} + +func currentRuntime(s *statsServer) *statsRuntime { + s.runMu.Lock() + defer s.runMu.Unlock() + return s.run +} + +func TestStatsServer_reload_initial_disabled(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{"type": "none"}) + + require.NoError(t, s.reload(c, true)) + assert.False(t, s.enabled.Load()) + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_initial_invalidInterval(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "host": "127.0.0.1:0", + "prefix": "test", + }) + + err := s.reload(c, true) + require.Error(t, err) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_initial_unknownType(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "carbon", + "interval": "1s", + }) + + err := s.reload(c, true) + require.Error(t, err) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_unchanged_noOp(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, 
map[string]any{"type": "none"}) + + require.NoError(t, s.reload(c, true)) + require.NoError(t, s.reload(c, false)) + assert.False(t, s.enabled.Load()) +} + +func TestStatsServer_reload_initial_graphite(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "interval": "1s", + "protocol": "tcp", + "host": "127.0.0.1:2003", + "prefix": "test", + }) + + require.NoError(t, s.reload(c, true)) + assert.True(t, s.enabled.Load()) + // reload only records config; Start builds the runtime. + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_initial_prometheus(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + + require.NoError(t, s.reload(c, true)) + assert.True(t, s.enabled.Load()) + // reload only records config; Start builds the runtime. + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_Start_graphite_blocksUntilStop(t *testing.T) { + sink := newGraphiteSink(t) + defer sink.Close() + + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "graphite", + "interval": "1s", + "protocol": "tcp", + "host": sink.Addr(), + "prefix": "test", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + + // Wait for Start to publish runtime state. 
+ waitFor(t, func() bool { return currentRuntime(s) != nil }) + rt := currentRuntime(s) + require.NotNil(t, rt) + assert.Nil(t, rt.listener, "graphite has no listener") + + s.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("graphite Start did not return after Stop") + } + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_StartStop_lifecycle(t *testing.T) { + port := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + + waitForListening(t, "127.0.0.1:"+port) + rt := currentRuntime(s) + require.NotNil(t, rt) + require.NotNil(t, rt.listener) + + s.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_disable_stopsRunningRuntime(t *testing.T) { + port := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + waitForListening(t, "127.0.0.1:"+port) + + setStatsConfig(c, map[string]any{"type": "none"}) + require.NoError(t, s.reload(c, false)) + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after reload disabled stats") + } + assert.False(t, s.enabled.Load()) + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_reload_changeListener_restartsListener(t *testing.T) { + port1 := freeTCPPort(t) + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port1, + "path": 
"/metrics", + }) + require.NoError(t, s.reload(c, true)) + + firstDone := make(chan struct{}) + go func() { + s.Start() + close(firstDone) + }() + waitForListening(t, "127.0.0.1:"+port1) + first := currentRuntime(s) + require.NotNil(t, first) + + port2 := freeTCPPort(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port2, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, false)) + + select { + case <-firstDone: + case <-time.After(5 * time.Second): + t.Fatal("old Start did not return after reload") + } + + waitForListening(t, "127.0.0.1:"+port2) + second := currentRuntime(s) + require.NotNil(t, second) + assert.NotSame(t, first, second, "expected a new runtime after listen address change") + + s.Stop() +} + +func TestStatsServer_Stop_beforeStart_doesNotBlock(t *testing.T) { + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + stopped := make(chan struct{}) + go func() { + s.Stop() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("Stop hung with no runtime started") + } +} + +func TestStatsServer_configTest_validatesWithoutSpawning(t *testing.T) { + s, c := newTestStatsServer(t) + s.configTest = true + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:0", + "path": "/metrics", + }) + + require.NoError(t, s.reload(c, true)) + s.Start() + assert.Nil(t, currentRuntime(s)) +} + +func TestStatsServer_ctxCancel_unblocksStart(t *testing.T) { + // Ensures ctx cancellation alone (no explicit Stop) tears down both + // graphite and prom Start invocations. 
+ port := freeTCPPort(t) + l := slog.New(slog.DiscardHandler) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := &statsServer{l: l, ctx: ctx} + c := config.NewC(l) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + waitForListening(t, "127.0.0.1:"+port) + + cancel() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after ctx cancel") + } +} + +func TestStatsServer_listenerBindFailure_sameCfgReloadRetries(t *testing.T) { + // Hold the port so ListenAndServe will fail on first Start. + blocker, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := strconv.Itoa(blocker.Addr().(*net.TCPAddr).Port) + + s, c := newTestStatsServer(t) + setStatsConfig(c, map[string]any{ + "type": "prometheus", + "interval": "1s", + "listen": "127.0.0.1:" + port, + "path": "/metrics", + }) + require.NoError(t, s.reload(c, true)) + + done := make(chan struct{}) + go func() { + s.Start() + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after bind failure") + } + // Bind failure should have dropped the cached config so a same-cfg + // SIGHUP can retry. + s.runMu.Lock() + cfgAfterFailure := s.runCfg + s.runMu.Unlock() + assert.Nil(t, cfgAfterFailure) + + // Free the port and reload with the same config; Start should fire again. 
+ require.NoError(t, blocker.Close()) + require.NoError(t, s.reload(c, false)) + + waitForListening(t, "127.0.0.1:"+port) + require.NotNil(t, currentRuntime(s)) + + s.Stop() +} + +func waitForListening(t *testing.T, addr string) { + t.Helper() + waitFor(t, func() bool { + conn, err := net.DialTimeout("tcp", addr, 200*time.Millisecond) + if err != nil { + return false + } + _ = conn.Close() + return true + }) +} + +// graphiteSink is a minimal TCP accept-and-discard server so graphite.Once +// calls in tests don't spam error logs or wedge on connection refused. +type graphiteSink struct { + ln net.Listener +} + +func newGraphiteSink(t *testing.T) *graphiteSink { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + g := &graphiteSink{ln: ln} + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + _, _ = io.Copy(io.Discard, c) + _ = c.Close() + }(conn) + } + }() + return g +} + +func (g *graphiteSink) Addr() string { return g.ln.Addr().String() } +func (g *graphiteSink) Close() { _ = g.ln.Close() } + +func freeTCPPort(t *testing.T) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := ln.Addr().(*net.TCPAddr).Port + require.NoError(t, ln.Close()) + return strconv.Itoa(port) +} diff --git a/test/logger.go b/test/logger.go index b5a717d8..faab0b69 100644 --- a/test/logger.go +++ b/test/logger.go @@ -1,29 +1,73 @@ package test import ( + "context" "io" + "log/slog" "os" + "time" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -func NewLogger() *logrus.Logger { - l := logrus.New() - +// NewLogger returns a *slog.Logger suitable for use in tests. Output goes to +// io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to +// stream output to stderr for local debugging. 
+func NewLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - return l + return slog.New(slog.DiscardHandler) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace } - - return l + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// NewLoggerWithOutput returns a *slog.Logger whose text output is captured by +// w. Timestamps are suppressed so tests can assert on exact output without +// baking the current time into expected strings. +func NewLoggerWithOutput(w io.Writer) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, nil)}) +} + +// NewLoggerWithOutputAndLevel is NewLoggerWithOutput with an explicit level +// so tests can exercise Enabled-gated paths. +func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// NewJSONLoggerWithOutput returns a *slog.Logger emitting JSON to w with +// timestamps suppressed, for tests that pin the JSON shape. +func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// stripTimeHandler zeros each record's time before delegating so slog's +// built-in handlers skip emitting the time attribute. Used to avoid +// timestamp-dependent assertions in tests without resorting to ReplaceAttr. 
+type stripTimeHandler struct { + inner slog.Handler +} + +func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool { + return h.inner.Enabled(ctx, l) +} + +func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error { + r.Time = time.Time{} + return h.inner.Handle(ctx, r) +} + +func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)} +} + +func (h *stripTimeHandler) WithGroup(name string) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithGroup(name)} } diff --git a/timeout_test.go b/timeout_test.go index db36fec7..ffeecc55 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -134,7 +134,7 @@ func TestTimerWheel_Purge(t *testing.T) { assert.True(t, tw.lastTick.After(lastTick)) // Make sure we get all 4 packets back - for i := 0; i < 4; i++ { + for i := range 4 { p, has := tw.Purge() assert.True(t, has) assert.Equal(t, fps[i], p) @@ -149,7 +149,7 @@ func TestTimerWheel_Purge(t *testing.T) { // Make sure we cached the free'd items assert.Equal(t, 4, tw.itemsCached) ci := tw.itemCache - for i := 0; i < 4; i++ { + for range 4 { assert.NotNil(t, ci) ci = ci.Next } diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..30d89dec 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -16,7 +16,7 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) + ListenOut(r EncReader) error WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool @@ -31,8 +31,8 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return +func (NoopConn) ListenOut(_ EncReader) error { + return nil } func (NoopConn) SupportsMultipleReaders() bool { return false diff --git a/udp/udp_android.go b/udp/udp_android.go index bb191954..3fc68003 100644 --- a/udp/udp_android.go +++ 
b/udp/udp_android.go @@ -9,11 +9,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 65ef31a5..c42a3c18 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -12,11 +12,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..8a4f5b18 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,12 +8,12 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) @@ -22,12 +22,12 @@ type StdConn struct { *net.UDPConn isV4 bool sysFd uintptr - l *logrus.Logger + l *slog.Logger } var _ Conn = &StdConn{} -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { @@ -173,11 +173,10 @@ func (u *StdConn) ListenOut(r EncReader) { n, 
rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } - u.l.WithError(err).Error("unexpected udp socket receive error") + u.l.Error("unexpected udp socket receive error", "error", err) } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) @@ -197,7 +196,7 @@ func (u *StdConn) Rebind() error { } if err != nil { - u.l.WithError(err).Error("Failed to rebind udp socket") + u.l.Error("Failed to rebind udp socket", "error", err) } return nil diff --git a/udp/udp_generic.go b/udp/udp_generic.go index e9dad6c5..131eb73b 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -12,22 +12,22 @@ import ( "context" "errors" "fmt" + "log/slog" "net" "net/netip" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type GenericConn struct { *net.UDPConn - l *logrus.Logger + l *slog.Logger } var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -73,7 +73,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) { +func (u *GenericConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -83,13 +83,12 @@ func (u *GenericConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - 
u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..3e2d726a 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -4,72 +4,73 @@ package udp import ( + "context" "encoding/binary" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + udpConn *net.UDPConn + rawConn syscall.RawConn + isV4 bool + l *slog.Logger + batch int } -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true - } - return ip, false -} - -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - af := unix.AF_INET6 - if ip.Is4() { - af = unix.AF_INET - } - syscall.ForkLock.RLock() - fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) - if err == nil { - unix.CloseOnExec(fd) - } - syscall.ForkLock.RUnlock() - +func setReusePort(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + //CloseOnExec already set by the runtime + }) + if err != nil { + return err + } + return opErr +} + +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { + listen := netip.AddrPortFrom(ip, uint16(port)) + lc := net.ListenConfig{} + if multi { + lc.Control = setReusePort + } + //this context is only used during the bind operation, you can't cancel it to kill the socket + pc, err := lc.ListenPacket(context.Background(), "udp", listen.String()) if err != nil { - unix.Close(fd) return nil, fmt.Errorf("unable to open socket: %s", err) } - - if multi { - if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, 
unix.SO_REUSEPORT, 1); err != nil { - return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) - } + udpConn := pc.(*net.UDPConn) + rawConn, err := udpConn.SyscallConn() + if err != nil { + _ = udpConn.Close() + return nil, err + } + //gotta find out if we got an AF_INET6 socket or not: + out := &StdConn{ + udpConn: udpConn, + rawConn: rawConn, + l: l, + batch: batch, } - var sa unix.Sockaddr - if ip.Is4() { - sa4 := &unix.SockaddrInet4{Port: port} - sa4.Addr = ip.As4() - sa = sa4 - } else { - sa6 := &unix.SockaddrInet6{Port: port} - sa6.Addr = ip.As16() - sa = sa6 - } - if err = unix.Bind(fd, sa); err != nil { - return nil, fmt.Errorf("unable to bind to socket: %s", err) + af, err := out.getSockOptInt(unix.SO_DOMAIN) + if err != nil { + _ = out.Close() + return nil, err } + out.isV4 = af == unix.AF_INET - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return out, nil } func (u *StdConn) SupportsMultipleReaders() bool { @@ -80,62 +81,133 @@ func (u *StdConn) Rebind() error { return nil } +func (u *StdConn) getSockOptInt(opt int) (int, error) { + if u.rawConn == nil { + return 0, fmt.Errorf("no UDP connection") + } + var out int + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + out, opErr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, opt) + }) + if err != nil { + return 0, err + } + return out, opErr +} + +func (u *StdConn) setSockOptInt(opt int, n int) error { + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, opt, n) + }) + if err != nil { + return err + } + return opErr +} + func (u *StdConn) SetRecvBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) + return u.setSockOptInt(unix.SO_RCVBUFFORCE, n) } func (u *StdConn) SetSendBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) + return 
u.setSockOptInt(unix.SO_SNDBUFFORCE, n) } func (u *StdConn) SetSoMark(mark int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) + return u.setSockOptInt(unix.SO_MARK, mark) } func (u *StdConn) GetRecvBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) + return u.getSockOptInt(unix.SO_RCVBUF) } func (u *StdConn) GetSendBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) + return u.getSockOptInt(unix.SO_SNDBUF) } func (u *StdConn) GetSoMark() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) + return u.getSockOptInt(unix.SO_MARK) } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { - sa, err := unix.Getsockname(u.sysFd) - if err != nil { - return netip.AddrPort{}, err - } + a := u.udpConn.LocalAddr() - switch sa := sa.(type) { - case *unix.SockaddrInet4: - return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil - - case *unix.SockaddrInet6: - return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil + switch v := a.(type) { + case *net.UDPAddr: + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } -func (u *StdConn) ListenOut(r EncReader) { +func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { + var errno syscall.Errno + n, _, errno := unix.Syscall6( + unix.SYS_RECVMMSG, + fd, + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + unix.MSG_WAITFORONE, + 0, + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + // No data available, block for I/O and try again. 
+ return int(n), false, nil + } + if errno != 0 { + return int(n), true, &net.OpError{Op: "recvmmsg", Err: errno} + } + return int(n), true, nil +} + +func (u *StdConn) listenOutSingle(r EncReader) error { + var err error + var n int + var from netip.AddrPort + buffer := make([]byte, MTU) + + for { + n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) + if err != nil { + return err + } + from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) + r(from, buffer[:n]) + } +} + +func (u *StdConn) listenOutBatch(r EncReader) error { var ip netip.Addr + var n int + var operr error msgs, buffers, names := u.PrepareRawMessages(u.batch) - read := u.ReadMulti - if u.batch == 1 { - read = u.ReadSingle + + //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read + //defining it outside the loop so it gets re-used + reader := func(fd uintptr) (done bool) { + n, done, operr = recvmmsg(fd, msgs) + return done } for { - n, err := read(msgs) + err := u.rawConn.Read(reader) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err + } + if operr != nil { + return operr } for i := 0; i < n; i++ { @@ -150,106 +222,17 @@ func (u *StdConn) ListenOut(r EncReader) { } } -func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, - 0, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmsg", Err: err} - } - - msgs[0].Len = uint32(n) - return 1, nil - } -} - -func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), - unix.MSG_WAITFORONE, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmmsg", Err: err} - } - - return int(n), nil +func (u *StdConn) ListenOut(r EncReader) error { + if u.batch == 1 
{ + return u.listenOutSingle(r) + } else { + return u.listenOutBatch(r) } } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) - } - return u.writeTo6(b, ip) -} - -func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { - var rsa unix.RawSockaddrInet6 - rsa.Family = unix.AF_INET6 - rsa.Addr = ip.Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet6), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } -} - -func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { - if !ip.Addr().Is4() { - return ErrInvalidIPv6RemoteForSocket - } - - var rsa unix.RawSockaddrInet4 - rsa.Family = unix.AF_INET - rsa.Addr = ip.Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet4), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } + _, err := u.udpConn.WriteToUDPAddrPort(b, ip) + return err } func (u *StdConn) ReloadConfig(c *config.C) { @@ -259,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetRecvBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.read_buffer was set") + u.l.Info("listen.read_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.read_buffer") + u.l.Warn("Failed to get listen.read_buffer", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.read_buffer") + u.l.Error("Failed to set listen.read_buffer", "error", err) } 
} @@ -274,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSendBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.write_buffer was set") + u.l.Info("listen.write_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.write_buffer") + u.l.Warn("Failed to get listen.write_buffer", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.write_buffer") + u.l.Error("Failed to set listen.write_buffer", "error", err) } } @@ -290,27 +273,40 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSoMark() if err == nil { - u.l.WithField("mark", s).Info("listen.so_mark was set") + u.l.Info("listen.so_mark was set", "mark", s) } else { - u.l.WithError(err).Warn("Failed to get listen.so_mark") + u.l.Warn("Failed to get listen.so_mark", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.so_mark") + u.l.Error("Failed to set listen.so_mark", "error", err) } } } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS - _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) - if err != 0 { + + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + _, _, syserr := unix.Syscall6(unix.SYS_GETSOCKOPT, fd, uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) + if syserr != 0 { + opErr = syserr + } + }) + if err != nil { return err } - return nil + return opErr } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + if u.udpConn != nil { + return u.udpConn.Close() + } + return nil } func NewUDPStatsEmitter(udpConns []Conn) func() { diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go 
index 3b69159a..4b2de75a 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -11,11 +11,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 3d60f34c..d110af19 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/netip" "sync" @@ -17,7 +18,6 @@ import ( "time" "unsafe" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" @@ -53,14 +53,14 @@ type ringBuffer struct { type RIOConn struct { isOpen atomic.Bool - l *logrus.Logger + l *slog.Logger sock windows.Handle rx, tx ringBuffer rq winrio.Rq results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { +func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } @@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro return u, nil } -func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { +func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error { var err error u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) if err != nil { @@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. 
- l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl") + l.Debug("failed to set UDP_CONNRESET ioctl", "error", err) } ret = 0 @@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. - l.WithError(err).Debug("failed to set UDP_NETRESET ioctl") + l.Debug("failed to set UDP_NETRESET ioctl", "error", err) } err = u.rx.Open() @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) { +func (u *RIOConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -151,13 +151,12 @@ func (u *RIOConn) ListenOut(r EncReader) { if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..fcd0967c 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,11 +4,13 @@ package udp import ( + "context" "io" + "log/slog" "net/netip" - "sync/atomic" + "os" + "sync" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -36,15 +38,24 @@ type TesterConn struct { RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula - closed atomic.Bool - l *logrus.Logger + // done is closed exactly once by Close. Senders select on it so they + // never race with a channel close; readers exit when it fires. 
The + // packet channels are intentionally never closed - that was the source + // of `send on closed channel` panics when a WriteTo/Send from another + // goroutine passed the close check and reached the send just after + // Close ran. + done chan struct{} + closeOnce sync.Once + + l *slog.Logger } -func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), + done: make(chan struct{}), l: l, }, nil } @@ -53,21 +64,21 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { - if u.closed.Load() { - return - } - h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) } - if u.l.Level >= logrus.DebugLevel { - u.l.WithField("header", h). - WithField("udpAddr", packet.From). - WithField("dataLen", len(packet.Data)). 
- Debug("UDP receiving injected packet") + if u.l.Enabled(context.Background(), slog.LevelDebug) { + u.l.Debug("UDP receiving injected packet", + "header", h, + "udpAddr", packet.From, + "dataLen", len(packet.Data), + ) + } + select { + case <-u.done: + case u.RxPackets <- packet: } - u.RxPackets <- packet } // Get will pull a UdpPacket from the transmit queue @@ -75,7 +86,12 @@ func (u *TesterConn) Send(packet *Packet) { // packets were ingested from the tun side (in most cases), you can send them with Tun.Send func (u *TesterConn) Get(block bool) *Packet { if block { - return <-u.TxPackets + select { + case <-u.done: + return nil + case p := <-u.TxPackets: + return p + } } select { @@ -91,10 +107,6 @@ func (u *TesterConn) Get(block bool) *Packet { //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { - if u.closed.Load() { - return io.ErrClosedPipe - } - p := &Packet{ Data: make([]byte, len(b), len(b)), From: u.Addr, @@ -102,17 +114,22 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { } copy(p.Data, b) - u.TxPackets <- p - return nil + select { + case <-u.done: + return io.ErrClosedPipe + case u.TxPackets <- p: + return nil + } } -func (u *TesterConn) ListenOut(r EncReader) { +func (u *TesterConn) ListenOut(r EncReader) error { for { - p, ok := <-u.RxPackets - if !ok { - return + select { + case <-u.done: + return os.ErrClosed + case p := <-u.RxPackets: + r(p.From, p.Data) } - r(p.From, p.Data) } } @@ -136,9 +153,8 @@ func (u *TesterConn) Rebind() error { } func (u *TesterConn) Close() error { - if u.closed.CompareAndSwap(false, true) { - close(u.RxPackets) - close(u.TxPackets) - } + u.closeOnce.Do(func() { + close(u.done) + }) return nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 1b777c37..7969f7e8 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -5,14 +5,13 @@ package udp import 
( "fmt" + "log/slog" "net" "net/netip" "syscall" - - "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between @@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return rc, nil } - l.WithError(err).Error("Falling back to standard udp sockets") + l.Error("Falling back to standard udp sockets", "error", err) return NewGenericListener(l, ip, port, multi, batch) } diff --git a/util/error.go b/util/error.go index 814c77a1..14371d3f 100644 --- a/util/error.go +++ b/util/error.go @@ -1,10 +1,10 @@ package util import ( + "context" "errors" "fmt" - - "github.com/sirupsen/logrus" + "log/slog" ) type ContextualError struct { @@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error { } // LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError -func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { +func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) { switch v := err.(type) { case *ContextualError: v.Log(l) default: - l.WithError(err).Error(msg) + l.Error(msg, "error", err) } } @@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error { return ce.RealError } -func (ce *ContextualError) Log(lr *logrus.Logger) { - if ce.RealError != nil { - lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) - } else { - lr.WithFields(ce.Fields).Error(ce.Context) +// Log emits ce as a single error-level log line with Fields and RealError +// promoted to top-level attributes, producing a flat shape callers can grep +// or parse without walking into a nested object. 
+func (ce *ContextualError) Log(l *slog.Logger) { + attrs := make([]slog.Attr, 0, len(ce.Fields)+1) + for k, v := range ce.Fields { + attrs = append(attrs, slog.Any(k, v)) } + if ce.RealError != nil { + attrs = append(attrs, slog.Any("error", ce.RealError)) + } + // LogAttrs is intentional: attrs is built from a map[string]any so it has + // no pair-form equivalent. + //nolint:sloglint + l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...) } diff --git a/util/error_test.go b/util/error_test.go index 692c1840..30e39e33 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -1,95 +1,67 @@ package util import ( + "bytes" "errors" "fmt" "testing" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) type m = map[string]any -type TestLogWriter struct { - Logs []string -} - -func NewTestLogWriter() *TestLogWriter { - return &TestLogWriter{Logs: make([]string, 0)} -} - -func (tl *TestLogWriter) Write(p []byte) (n int, err error) { - tl.Logs = append(tl.Logs, string(p)) - return len(p), nil -} - -func (tl *TestLogWriter) Reset() { - tl.Logs = tl.Logs[:0] -} - func TestContextualError_Log(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test a full context line - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test a line with an error and msg but no fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) + assert.Equal(t, 
"level=ERROR msg=\"test message\" error=error\n", buf.String()) // Test just a context and fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", m{"field": "1"}, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String()) // Test just a context - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String()) // Test just an error - tl.Reset() + buf.Reset() e = NewContextualError("", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String()) } func TestLogWithContextIfNeeded(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test ignoring fallback context - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) LogWithContextIfNeeded("This should get thrown away", e, l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test using fallback context - tl.Reset() + buf.Reset() err := fmt.Errorf("this is a normal error") LogWithContextIfNeeded("Fallback context woo", err, l) - assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String()) } func TestContextualizeIfNeeded(t *testing.T) {