Merge remote-tracking branch 'origin/master' into fips140

This commit is contained in:
Wade Simmons
2026-04-27 14:13:47 -04:00
117 changed files with 6608 additions and 2221 deletions

View File

@@ -209,10 +209,11 @@ jobs:
id: create_release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REF_NAME: ${{ github.ref_name }}
run: |
cd artifacts
gh release create \
--verify-tag \
--title "Release ${{ github.ref_name }}" \
"${{ github.ref_name }}" \
--title "Release ${GITHUB_REF_NAME}" \
"${GITHUB_REF_NAME}" \
SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz

View File

@@ -18,6 +18,8 @@ jobs:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run extra smoke tests
runs-on: ubuntu-latest
env:
VAGRANT_DEFAULT_PROVIDER: libvirt
steps:
- uses: actions/checkout@v6
@@ -30,8 +32,13 @@ jobs:
- name: add hashicorp source
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
- name: install vagrant
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
- name: install vagrant and libvirt
run: |
sudo apt-get update && sudo apt-get install -y vagrant libvirt-daemon-system libvirt-dev
sudo chmod 666 /dev/kvm
sudo usermod -aG libvirt $(whoami)
sudo chmod 666 /var/run/libvirt/libvirt-sock
vagrant plugin install vagrant-libvirt
- name: freebsd-amd64
run: make smoke-vagrant/freebsd-amd64
@@ -42,10 +49,19 @@ jobs:
- name: netbsd-amd64
run: make smoke-vagrant/netbsd-amd64
- name: linux-386
run: make smoke-vagrant/linux-386
- name: linux-amd64-ipv6disable
run: make smoke-vagrant/linux-amd64-ipv6disable
# linux-386 runs last because it requires disabling KVM to use VirtualBox,
# which prevents libvirt (used by the other tests) from working after this point.
- name: install virtualbox for i386 test
run: |
sudo apt-get install -y virtualbox
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
- name: linux-386
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
run: make smoke-vagrant/linux-386
timeout-minutes: 30

View File

@@ -16,8 +16,10 @@ relay:
am_relay: true
EOF
export LIGHTHOUSES="192.168.100.1 172.17.0.2:4242"
export REMOTE_ALLOW_LIST='{"172.17.0.4/32": false, "172.17.0.5/32": false}'
# TEST-NET-3 placeholder IPs; smoke-relay.sh seds them to real container IPs.
# Mapping: .2 lighthouse1, .3 host2, .4 host3, .5 host4.
export LIGHTHOUSES="192.168.100.1 203.0.113.2:4242"
export REMOTE_ALLOW_LIST='{"203.0.113.4/32": false, "203.0.113.5/32": false}'
HOST="host2" ../genconfig.sh >host2.yml <<EOF
relay:
@@ -25,7 +27,7 @@ relay:
- 192.168.100.1
EOF
export REMOTE_ALLOW_LIST='{"172.17.0.3/32": false}'
export REMOTE_ALLOW_LIST='{"203.0.113.3/32": false}'
HOST="host3" ../genconfig.sh >host3.yml

View File

@@ -5,9 +5,15 @@ set -e -x
rm -rf ./build
mkdir ./build
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
# Smoke containers run on a dedicated docker network whose subnet is allocated
# at smoke time, not known at build time. Configs are written with TEST-NET-3
# placeholder IPs (RFC 5737) and smoke.sh / smoke-vagrant.sh / smoke-relay.sh
# sed the real container IPs in before starting nebula.
#
# Placeholder mapping (last octet == fixed container slot):
# 203.0.113.2 -> lighthouse1, 203.0.113.3 -> host2,
# 203.0.113.4 -> host3, 203.0.113.5 -> host4.
LIGHTHOUSE_IP="203.0.113.2"
(
cd build
@@ -25,16 +31,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{
../genconfig.sh >lighthouse1.yml
HOST="host2" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
../genconfig.sh >host2.yml
HOST="host3" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host3.yml
HOST="host4" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml

View File

@@ -6,6 +6,8 @@ set -o pipefail
mkdir -p logs
NETWORK="nebula-smoke-relay"
cleanup() {
echo
echo " *** cleanup"
@@ -16,22 +18,53 @@ cleanup() {
then
docker kill lighthouse1 host2 host3 host4
fi
docker network rm "$NETWORK" >/dev/null 2>&1
}
trap cleanup EXIT
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
# Create a dedicated smoke network with an explicit subnet (required for --ip
# below). Probe a short list of candidates so a locally-used range doesn't
# fail the whole test — we only need one to be free.
docker network rm "$NETWORK" >/dev/null 2>&1 || true
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
break
fi
done
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
exit 1
fi
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build-relay.sh.
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
PREFIX="${SUBNET%/*}"
PREFIX="${PREFIX%.*}"
LIGHTHOUSE_IP="$PREFIX.2"
HOST2_IP="$PREFIX.3"
HOST3_IP="$PREFIX.4"
HOST4_IP="$PREFIX.5"
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
for f in build/host2.yml build/host3.yml build/host4.yml; do
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
mv "$f.tmp" "$f"
done
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" nebula:smoke-relay -config host2.yml -test
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" nebula:smoke-relay -config host3.yml -test
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" nebula:smoke-relay -config host4.yml -test
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
sleep 1
set +x
@@ -76,7 +109,13 @@ docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 5
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
# fixed sleep.
for _ in $(seq 1 30); do
[ -z "$(jobs -r)" ] && break
sleep 1
done
if [ "$(jobs -r)" ]
then

View File

@@ -8,6 +8,8 @@ export VAGRANT_CWD="$PWD/vagrant-$1"
mkdir -p logs
NETWORK="nebula-smoke"
cleanup() {
echo
echo " *** cleanup"
@@ -19,21 +21,51 @@ cleanup() {
docker kill lighthouse1 host2
fi
vagrant destroy -f
docker network rm "$NETWORK" >/dev/null 2>&1
}
trap cleanup EXIT
# Create a dedicated smoke network with an explicit subnet (required for --ip
# below). Probe a short list of candidates so a locally-used range doesn't
# fail the whole test — we only need one to be free.
docker network rm "$NETWORK" >/dev/null 2>&1 || true
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
break
fi
done
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
exit 1
fi
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
# .3 host2 — matches the placeholders in build.sh.
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
PREFIX="${SUBNET%/*}"
PREFIX="${PREFIX%.*}"
LIGHTHOUSE_IP="$PREFIX.2"
HOST2_IP="$PREFIX.3"
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
# This must happen before `vagrant up` rsyncs build/ into the VM for host3.
for f in build/host2.yml build/host3.yml; do
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
mv "$f.tmp" "$f"
done
CONTAINER="nebula:${NAME:-smoke}"
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
vagrant up
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 15
@@ -96,7 +128,14 @@ vagrant ssh -c "ping -c1 192.168.100.2" -- -T
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
# Wait up to 30s for all backgrounded jobs to exit. vagrant ssh in particular
# takes a beat to tear down after nebula exits on the VM, so a fixed sleep is
# racy.
for _ in $(seq 1 30); do
[ -z "$(jobs -r)" ] && break
sleep 1
done
if [ "$(jobs -r)" ]
then

View File

@@ -6,6 +6,8 @@ set -o pipefail
mkdir -p logs
NETWORK="nebula-smoke"
cleanup() {
echo
echo " *** cleanup"
@@ -16,38 +18,71 @@ cleanup() {
then
docker kill lighthouse1 host2 host3 host4
fi
docker network rm "$NETWORK" >/dev/null 2>&1
}
trap cleanup EXIT
# Create a dedicated smoke network with an explicit subnet (required for --ip
# below). Probe a short list of candidates so a locally-used range doesn't
# fail the whole test — we only need one to be free.
docker network rm "$NETWORK" >/dev/null 2>&1 || true
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
break
fi
done
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
exit 1
fi
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build.sh.
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
PREFIX="${SUBNET%/*}"
PREFIX="${PREFIX%.*}"
LIGHTHOUSE_IP="$PREFIX.2"
HOST2_IP="$PREFIX.3"
HOST3_IP="$PREFIX.4"
HOST4_IP="$PREFIX.5"
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
# build/lighthouse1.yml has no IPs to rewrite so it's skipped.
for f in build/host2.yml build/host3.yml build/host4.yml; do
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
mv "$f.tmp" "$f"
done
CONTAINER="nebula:${NAME:-smoke}"
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" "$CONTAINER" -config host3.yml -test
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" "$CONTAINER" -config host4.yml -test
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
sleep 1
# grab tcpdump pcaps for debugging
docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
docker exec host2 ncat -nklv 0.0.0.0 2000 &
docker exec host3 ncat -nklv 0.0.0.0 2000 &
docker exec host4 ncat -e '/usr/bin/echo helloagainfromhost4' -nkluv 0.0.0.0 4000 &
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
@@ -119,17 +154,24 @@ echo
echo " *** Testing conntrack"
echo
set -x
# host2 can ping host3 now that host3 pinged it first
docker exec host2 ping -c1 192.168.100.3
# host4 can ping host2 once conntrack established
docker exec host2 ping -c1 192.168.100.4
docker exec host4 ping -c1 192.168.100.2
# host4's outbound firewall only allows ICMP to the lighthouse, so host4
# cannot initiate UDP to host2. Once host2 initiates a flow to host4:4000,
# conntrack must let host4's listener reply on that flow. If it doesn't,
# the echo back from host4 never reaches host2.
docker exec host2 sh -c "(/usr/bin/echo host2; sleep 2) | ncat -nuv 192.168.100.4 4000" | grep -q helloagainfromhost4
docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 5
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
# fixed sleep.
for _ in $(seq 1 30); do
[ -z "$(jobs -r)" ] && break
sleep 1
done
if [ "$(jobs -r)" ]
then

View File

@@ -1,7 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure("2") do |config|
config.vm.box = "ubuntu/jammy64"
config.vm.box = "bento/ubuntu-24.04"
config.vm.synced_folder "../build", "/nebula"

View File

@@ -1,7 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure("2") do |config|
config.vm.box = "generic/openbsd7"
config.vm.box = "DefinedNet/openbsd78"
config.vm.synced_folder "../build", "/nebula", type: "rsync"
end

View File

@@ -2,7 +2,21 @@ version: "2"
linters:
default: none
enable:
- sloglint
- testifylint
settings:
sloglint:
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
# custom levels instead of LogAttrs when you can.
#
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
# the few legitimate sites (where attrs is built up as a []slog.Attr)
# carry a //nolint:sloglint with rationale.
kv-only: true
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
# discard-handler is on by default (since Go 1.24): suggests
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
exclusions:
generated: lax
presets:

View File

@@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.10.3] - 2026-02-06
### Security
- Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations.
Both fingerprint representations will be tested against the blocklist.
Any newly issued P256 based certificates will have their signature clamped to the low-s form.
Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962)
### Changed
- Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588)
## [1.10.2] - 2026-01-21
### Fixed
@@ -775,7 +788,8 @@ created.)
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.2...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD
[1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3
[1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2
[1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0

1
CODEOWNERS Normal file
View File

@@ -0,0 +1 @@
#ECCN:Open Source

View File

@@ -57,7 +57,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
docker pull nebulaoss/nebula
```
#### Mobile
#### Mobile ([source code](https://github.com/DefinedNet/mobile_nebula))
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&amp;itscg=30200)
- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
@@ -76,6 +76,8 @@ Nebula was created to provide a mechanism for groups of hosts to communicate sec
## Getting started (quickly)
**Don't want to manage your own PKI and lighthouses?** [Managed Nebula](https://www.defined.net/) from Defined Networking handles all of this for you.
To set up a Nebula network, you'll need:
#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use.

38
bits.go
View File

@@ -1,8 +1,10 @@
package nebula
import (
"context"
"log/slog"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
type Bits struct {
@@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits {
return b
}
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current {
return true
@@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
}
// Not within the window
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("rejected a packet (top)",
"current", b.current,
"incoming", i,
)
}
return false
}
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
@@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "duplicate",
)
}
b.dupeCounter.Inc(1)
return false
@@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// In all other cases, fail and don't change current.
b.outOfWindowCounter.Inc(1)
if l.Level >= logrus.DebugLevel {
l.WithField("accepted", false).
WithField("currentCounter", b.current).
WithField("incomingCounter", i).
WithField("reason", "nonsense").
Debug("Receive window")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "nonsense",
)
}
return false
}

View File

@@ -1,5 +1,4 @@
//go:build boringcrypto
// +build boringcrypto
package nebula

View File

@@ -1,11 +1,14 @@
package cert
import (
"bufio"
"bytes"
"encoding/pem"
"errors"
"fmt"
"io"
"net/netip"
"slices"
"strings"
"time"
)
@@ -29,22 +32,46 @@ func NewCAPool() *CAPool {
// If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
pool := NewCAPool()
var err error
var expired bool
for {
caPEMs, err = pool.AddCAFromPEM(caPEMs)
if errors.Is(err, ErrExpired) {
expired = true
err = nil
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
}
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
// The reader must contain a PEM-encoded set of nebula certificates.
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
pool := NewCAPool()
var expired bool
scanner := bufio.NewScanner(r)
scanner.Split(SplitPEM)
for scanner.Scan() {
pemBytes := scanner.Bytes()
block, rest := pem.Decode(pemBytes)
if len(bytes.TrimSpace(rest)) > 0 {
return nil, ErrInvalidPEMBlock
}
if block == nil {
return nil, ErrInvalidPEMBlock
}
c, err := unmarshalCertificateBlock(block)
if err != nil {
return nil, err
}
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
err = pool.AddCA(c)
if errors.Is(err, ErrExpired) {
expired = true
continue
} else if err != nil {
return nil, err
}
}
if err := scanner.Err(); err != nil {
return nil, ErrInvalidPEMBlock
}
if expired {
return pool, ErrExpired
@@ -141,10 +168,23 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
return nil, err
}
// Pre nebula v1.10.3 could generate signatures in either high or low s form and validation
// of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form
// but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we
// need to test both forms until such a time comes that we enforce low-s form on signature validation.
fp2, err := CalculateAlternateFingerprint(c)
if err != nil {
return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err)
}
if fp2 != "" && ncp.IsBlocklisted(fp2) {
return nil, ErrBlockListed
}
cc := CachedCertificate{
Certificate: c,
InvertedGroups: make(map[string]struct{}),
Fingerprint: fp,
fingerprint2: fp2,
signerFingerprint: signer.Fingerprint,
}
@@ -158,6 +198,11 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
// is a cheaper operation to perform as a result.
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
// Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s
if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) {
return ErrBlockListed
}
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
return err
}

View File

@@ -1,10 +1,14 @@
package cert
import (
"bytes"
"io"
"net/netip"
"strings"
"testing"
"time"
"github.com/slackhq/nebula/cert/p256"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -111,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
assert.Len(t, ppppp.CAs, 1)
}
// oneByteReader wraps a reader to return at most 1 byte per Read call,
// exercising the streaming accumulation logic in NewCAPoolFromPEMReader.
type oneByteReader struct {
r io.Reader
}
func (o *oneByteReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
return o.r.Read(p[:1])
}
// TestNewCAPoolFromPEMReader_EmptyReader verifies that feeding no data, or
// only whitespace, yields an empty pool with no error.
func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) {
	for _, src := range []io.Reader{
		bytes.NewReader(nil),
		strings.NewReader(" \n\t\n "),
	} {
		pool, err := NewCAPoolFromPEMReader(src)
		require.NoError(t, err)
		assert.Empty(t, pool.CAs)
	}
}
// TestNewCAPoolFromPEMReader_OneByteReads streams two CA PEMs through a
// reader that returns one byte per call, exercising the streaming
// accumulation logic in NewCAPoolFromPEMReader, then confirms both CAs
// landed in the pool under their fingerprints.
func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) {
	ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
	ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)

	// Build the bundle in a fresh slice: append(pem1, pem2...) could write
	// into pem1's backing array if it happens to have spare capacity,
	// silently corrupting pem1 for any later use.
	bundle := append(append([]byte(nil), pem1...), pem2...)

	pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)})
	require.NoError(t, err)
	assert.Len(t, pool.CAs, 2)

	fp1, err := ca1.Fingerprint()
	require.NoError(t, err)
	fp2, err := ca2.Fingerprint()
	require.NoError(t, err)
	assert.Contains(t, pool.CAs, fp1)
	assert.Contains(t, pool.CAs, fp2)
}
// TestNewCAPoolFromPEMReader_TruncatedPEM verifies that a PEM header with no
// matching footer surfaces ErrInvalidPEMBlock.
func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) {
	truncated := strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata")
	_, err := NewCAPoolFromPEMReader(truncated)
	assert.ErrorIs(t, err, ErrInvalidPEMBlock)
}
// TestNewCAPoolFromPEMReader_TrailingGarbage verifies that non-PEM bytes
// after a valid certificate cause ErrInvalidPEMBlock rather than being
// silently ignored.
func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) {
	_, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)

	// Copy pem1 before appending so its backing array is never mutated:
	// append(pem1, ...) may reuse pem1's spare capacity.
	bundle := append(append([]byte(nil), pem1...), []byte("some trailing garbage")...)

	_, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle))
	assert.ErrorIs(t, err, ErrInvalidPEMBlock)
}
func TestCertificateV1_Verify(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
@@ -170,6 +228,15 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
// Create a copy of the cert and swap to the alternate form for the signature
nc := c.Copy()
b, err := p256.Swap(c.Signature())
require.NoError(t, err)
require.NoError(t, nc.(*certificateV1).setSignature(b))
_, err = caPool.VerifyCertificate(time.Now(), nc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
@@ -187,7 +254,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
b, err = caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
@@ -196,7 +263,17 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
})
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
cc, err := caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Reset the blocklist and block the alternate form fingerprint
caPool.ResetCertBlocklist()
caPool.BlocklistFingerprint(cc.fingerprint2)
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.NoError(t, err)
}
@@ -394,6 +471,15 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
// Create a copy of the cert and swap to the alternate form for the signature
nc := c.Copy()
b, err := p256.Swap(c.Signature())
require.NoError(t, err)
require.NoError(t, nc.(*certificateV2).setSignature(b))
_, err = caPool.VerifyCertificate(time.Now(), nc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
@@ -411,7 +497,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
b, err = caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
@@ -420,7 +506,17 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
})
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
cc, err := caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Reset the blocklist and block the alternate form fingerprint
caPool.ResetCertBlocklist()
caPool.BlocklistFingerprint(cc.fingerprint2)
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.NoError(t, err)
}

View File

@@ -4,6 +4,8 @@ import (
"fmt"
"net/netip"
"time"
"github.com/slackhq/nebula/cert/p256"
)
type Version uint8
@@ -110,6 +112,9 @@ type CachedCertificate struct {
InvertedGroups map[string]struct{}
Fingerprint string
signerFingerprint string
// A place to store a 2nd fingerprint if the certificate could have one, such as with P256
fingerprint2 string
}
func (cc *CachedCertificate) String() string {
@@ -152,3 +157,31 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
return c, nil
}
// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates
// whose ECDSA signature S component can be expressed in either low-S or high-S form.
// Non-P256 certificates have no alternate form; an empty string and nil error are returned.
// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step.
func CalculateAlternateFingerprint(c Certificate) (string, error) {
	if c.Curve() != Curve_P256 {
		return "", nil
	}

	// Flip the signature to its opposite (high/low) S form.
	swapped, err := p256.Swap(c.Signature())
	if err != nil {
		return "", err
	}

	// Mutate a copy so the caller's certificate is left untouched.
	nc := c.Copy()
	switch v := nc.(type) {
	case *certificateV1:
		if err = v.setSignature(swapped); err != nil {
			return "", err
		}
	case *certificateV2:
		if err = v.setSignature(swapped); err != nil {
			return "", err
		}
	default:
		return "", ErrUnknownVersion
	}

	return nc.Fingerprint()
}

127
cert/p256/p256.go Normal file
View File

@@ -0,0 +1,127 @@
package p256
import (
"crypto/elliptic"
"errors"
"math/big"
"filippo.io/bigmod"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// halfN is floor(N/2) for the P-256 group order; signatures whose S
// component is <= halfN are considered to be in the "low-S"
// (normalized) form.
var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1)

// nMod is the P-256 group order N as a bigmod modulus, used for
// modular arithmetic when negating S.
var nMod *bigmod.Modulus

func init() {
	// N is a fixed, well-formed curve parameter, so a failure here is a
	// programmer error and panicking at init time is appropriate.
	n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes())
	if err != nil {
		panic(err)
	}
	nMod = n
}
// IsNormalized reports whether the ASN.1 DER signature sig already
// carries its S component in the low-S form. A parse failure is
// returned as an error.
func IsNormalized(sig []byte) (bool, error) {
	rBytes, sBytes, parseErr := parseSignature(sig)
	if parseErr != nil {
		return false, parseErr
	}
	return checkLowS(rBytes, sBytes), nil
}
// checkLowS reports whether the S component (big-endian bytes) lies in
// the low half of the group order. The midpoint N/2 itself counts as
// low. R does not affect the result and is ignored.
func checkLowS(_, s []byte) bool {
	sVal := new(big.Int).SetBytes(s)
	// Check if S <= (N/2), because we want to include the midpoint in the set of low-s
	return sVal.Cmp(halfN) <= 0
}
// swap negates the S component modulo the P-256 order N, converting a
// signature between its low-S and high-S forms. R is returned
// unchanged. Leading zero bytes are trimmed from the result (keeping
// at least one byte) so it re-encodes as a minimal ASN.1 INTEGER.
func swap(r, s []byte) ([]byte, []byte, error) {
	bigS, err := bigmod.NewNat().SetBytes(s, nMod)
	if err != nil {
		return nil, nil, err
	}

	// Subtracting S from the modulus value computes (-S) mod N, i.e. N-S.
	sNegated := nMod.Nat().Sub(bigS, nMod)
	result := sNegated.Bytes(nMod)

	// Trim leading zeroes while always keeping at least one byte.
	for len(result) > 1 && result[0] == 0 {
		result = result[1:]
	}

	return r, result, nil
}
// Normalize returns sig unchanged when its S component is already in
// the low-S form, and otherwise re-encodes the signature with S
// replaced by N-S.
func Normalize(sig []byte) ([]byte, error) {
	r, s, err := parseSignature(sig)
	if err != nil {
		return nil, err
	}

	if checkLowS(r, s) {
		// Already normalized; hand back the original encoding untouched.
		return sig, nil
	}

	flippedR, flippedS, swapErr := swap(r, s)
	if swapErr != nil {
		return nil, swapErr
	}
	return encodeSignature(flippedR, flippedS)
}
// Swap converts sig between its two equivalent representations:
// whichever of the low-S or high-S form it currently uses, the
// opposite form is returned.
func Swap(sig []byte) ([]byte, error) {
	r, s, err := parseSignature(sig)
	if err != nil {
		return nil, err
	}

	otherR, otherS, swapErr := swap(r, s)
	if swapErr != nil {
		return nil, swapErr
	}
	return encodeSignature(otherR, otherS)
}
// parseSignature taken exactly from crypto/ecdsa/ecdsa.go
// It decodes an ASN.1 DER SEQUENCE of two INTEGERs (r, s), rejecting
// any trailing data either inside or after the sequence.
func parseSignature(sig []byte) (r, s []byte, err error) {
	var inner cryptobyte.String
	input := cryptobyte.String(sig)
	if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
		!input.Empty() ||
		!inner.ReadASN1Integer(&r) ||
		!inner.ReadASN1Integer(&s) ||
		!inner.Empty() {
		return nil, nil, errors.New("invalid ASN.1")
	}
	return r, s, nil
}
// encodeSignature serializes r and s into an ASN.1 DER SEQUENCE of two
// INTEGERs, the standard ECDSA signature wire format.
func encodeSignature(r, s []byte) ([]byte, error) {
	var b cryptobyte.Builder
	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
		addASN1IntBytes(b, r)
		addASN1IntBytes(b, s)
	})
	return b.Bytes()
}
// addASN1IntBytes encodes in ASN.1 a positive integer represented as
// a big-endian byte slice with zero or more leading zeroes.
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
	// Strip leading zeroes: ASN.1 INTEGER requires a minimal encoding.
	for len(bytes) > 0 && bytes[0] == 0 {
		bytes = bytes[1:]
	}
	if len(bytes) == 0 {
		// An empty (all-zero) value has no valid positive encoding.
		b.SetError(errors.New("invalid integer"))
		return
	}
	b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) {
		// A set high bit would be read as negative; prepend 0x00 to
		// keep the encoded value positive.
		if bytes[0]&0x80 != 0 {
			c.AddUint8(0)
		}
		c.AddBytes(bytes)
	})
}

28
cert/p256/p256_test.go Normal file
View File

@@ -0,0 +1,28 @@
package p256
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"testing"
"github.com/stretchr/testify/require"
)
// TestFlipping verifies that swap is an involution on the S component
// of a real ECDSA P-256 signature: one swap changes S, and a second
// swap restores the original bytes.
func TestFlipping(t *testing.T) {
	priv, err1 := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err1)
	out, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus"))
	require.NoError(t, err)
	r, s, err := parseSignature(out)
	require.NoError(t, err)
	r, s1, err := swap(r, s)
	require.NoError(t, err)
	r, s2, err := swap(r, s1)
	require.NoError(t, err)
	// Double swap must round-trip back to the original S; a single swap must not.
	require.Equal(t, s, s2)
	require.NotEqual(t, s, s1)
}

View File

@@ -1,12 +1,66 @@
package cert
import (
"bytes"
"encoding/pem"
"errors"
"fmt"
"golang.org/x/crypto/ed25519"
)
// ErrTruncatedPEMBlock is reported when input holds non-whitespace
// content that never forms a complete PEM block.
var ErrTruncatedPEMBlock = errors.New("truncated PEM block")

// SplitPEM is a split function for bufio.Scanner that returns each PEM
// block, from its "-----BEGIN " line through the end of the matching
// "-----END " line. Comments and whitespace before a block are skipped
// (consumed without producing a token); non-whitespace input that never
// opens or never closes a block yields ErrTruncatedPEMBlock.
func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) {
	begin := bytes.Index(data, []byte("-----BEGIN "))
	if begin < 0 {
		switch {
		case atEOF && len(bytes.TrimSpace(data)) > 0:
			// Leftover non-whitespace content with no PEM block in it.
			return 0, nil, ErrTruncatedPEMBlock
		case atEOF:
			// Only whitespace remains; consume it and finish cleanly.
			return len(data), nil, nil
		default:
			// Wait for more input before deciding.
			return 0, nil, nil
		}
	}

	rel := bytes.Index(data[begin:], []byte("-----END "))
	if rel < 0 {
		if atEOF {
			// The block opened but its END marker never arrived.
			return 0, nil, ErrTruncatedPEMBlock
		}
		// Need more data to locate the end of the block.
		return 0, nil, nil
	}

	// Extend the token through the newline that terminates the END line.
	endMarker := begin + rel
	nl := bytes.IndexByte(data[endMarker:], '\n')
	switch {
	case nl >= 0:
		stop := endMarker + nl + 1
		return stop, data[begin:stop], nil
	case atEOF:
		// END marker with no trailing newline at EOF — accept it anyway.
		return len(data), data[begin:], nil
	default:
		// Need more data to see whether a newline follows.
		return 0, nil, nil
	}
}
const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
@@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock
}
var c Certificate
var err error
switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
c, err := unmarshalCertificateBlock(p)
if err != nil {
return nil, r, err
}
@@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
// Only the Nebula certificate banners are accepted; any other block type
// yields ErrInvalidPEMCertificateBanner.
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
	// Implementations must validate the resulting certificate contains valid information
	if block.Type == CertificateBanner {
		return unmarshalCertificateV1(block.Bytes, nil)
	}
	if block.Type == CertificateV2Banner {
		return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
	}
	return nil, ErrInvalidPEMCertificateBanner
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())

View File

@@ -1,12 +1,88 @@
package cert
import (
"bufio"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// scanAll feeds input through a bufio.Scanner configured with SplitPEM
// and returns every token it produced plus the scanner's final error.
func scanAll(t *testing.T, input string) ([]string, error) {
	t.Helper()
	scanner := bufio.NewScanner(strings.NewReader(input))
	scanner.Split(SplitPEM)
	var blocks []string
	for scanner.Scan() {
		blocks = append(blocks, scanner.Text())
	}
	// Err is checked after the loop, per bufio.Scanner convention.
	return blocks, scanner.Err()
}
// TestSplitPEM_Single ensures a lone, well-formed PEM block is returned
// intact, including its trailing newline.
func TestSplitPEM_Single(t *testing.T) {
	input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n"
	blocks, err := scanAll(t, input)
	require.NoError(t, err)
	require.Len(t, blocks, 1)
	require.Equal(t, input, blocks[0])
}

// TestSplitPEM_Multiple ensures back-to-back blocks split into separate
// tokens in input order.
func TestSplitPEM_Multiple(t *testing.T) {
	block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n"
	block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
	blocks, err := scanAll(t, block1+block2)
	require.NoError(t, err)
	require.Len(t, blocks, 2)
	require.Equal(t, block1, blocks[0])
	require.Equal(t, block2, blocks[1])
}

// TestSplitPEM_CommentsAndWhitespaceBetweenBlocks ensures comments and
// blank lines around blocks are skipped rather than treated as errors.
func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) {
	input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
	blocks, err := scanAll(t, input)
	require.NoError(t, err)
	require.Len(t, blocks, 2)
}

// TestSplitPEM_Empty ensures empty input yields no blocks and no error.
func TestSplitPEM_Empty(t *testing.T) {
	blocks, err := scanAll(t, "")
	require.NoError(t, err)
	require.Empty(t, blocks)
}

// TestSplitPEM_WhitespaceOnly ensures pure whitespace is consumed
// silently with no blocks and no error.
func TestSplitPEM_WhitespaceOnly(t *testing.T) {
	blocks, err := scanAll(t, " \n\t\n ")
	require.NoError(t, err)
	require.Empty(t, blocks)
}

// TestSplitPEM_TrailingGarbage ensures the valid block is still
// returned but trailing non-PEM content surfaces ErrTruncatedPEMBlock.
func TestSplitPEM_TrailingGarbage(t *testing.T) {
	input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage"
	blocks, err := scanAll(t, input)
	require.ErrorIs(t, err, ErrTruncatedPEMBlock)
	require.Len(t, blocks, 1)
}

// TestSplitPEM_TruncatedBlock ensures a BEGIN with no matching END is
// reported as ErrTruncatedPEMBlock.
func TestSplitPEM_TruncatedBlock(t *testing.T) {
	input := "-----BEGIN TEST-----\npartial data with no end"
	_, err := scanAll(t, input)
	require.ErrorIs(t, err, ErrTruncatedPEMBlock)
}

// TestSplitPEM_NoEndNewline ensures a block whose END line lacks a
// trailing newline at EOF is still accepted as-is.
func TestSplitPEM_NoEndNewline(t *testing.T) {
	input := "-----BEGIN TEST-----\ndata\n-----END TEST-----"
	blocks, err := scanAll(t, input)
	require.NoError(t, err)
	require.Len(t, blocks, 1)
	require.Equal(t, input, blocks[0])
}

// TestSplitPEM_GarbageOnly ensures non-whitespace input with no PEM
// block at all is an error.
func TestSplitPEM_GarbageOnly(t *testing.T) {
	_, err := scanAll(t, "this is not PEM data")
	require.ErrorIs(t, err, ErrTruncatedPEMBlock)
}
func TestUnmarshalCertificateFromPEM(t *testing.T) {
goodCert := []byte(`
# A good cert

View File

@@ -9,6 +9,8 @@ import (
"fmt"
"net/netip"
"time"
"github.com/slackhq/nebula/cert/p256"
)
// TBSCertificate represents a certificate intended to be signed.
@@ -126,6 +128,13 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
return nil, err
}
if curve == Curve_P256 {
sig, err = p256.Normalize(sig)
if err != nil {
return nil, err
}
}
err = c.setSignature(sig)
if err != nil {
return nil, err

View File

@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/cert/p256"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -89,3 +90,48 @@ func TestCertificateV1_SignP256(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, uc)
}
func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
tbs := TBSCertificate{
Version: Version1,
Name: "testing",
Networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
UnsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: true,
Curve: Curve_P256,
}
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
rawPriv := priv.D.FillBytes(make([]byte, 32))
for i := 0; i < 1000; i++ {
if i&1 == 1 {
tbs.Version = Version1
} else {
tbs.Version = Version2
}
c, err := tbs.Sign(nil, Curve_P256, rawPriv)
require.NoError(t, err)
assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub))
normie, err := p256.IsNormalized(c.Signature())
require.NoError(t, err)
assert.True(t, normie)
}
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"os"
"strings"
"time"
"github.com/slackhq/nebula/cert"
@@ -40,23 +39,17 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return err
}
rawCACert, err := os.ReadFile(*vf.caPath)
caFile, err := os.Open(*vf.caPath)
if err != nil {
return fmt.Errorf("error while reading ca: %w", err)
}
defer caFile.Close()
caPool := cert.NewCAPool()
for {
rawCACert, err = caPool.AddCAFromPEM(rawCACert)
if err != nil {
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
if err != nil && !errors.Is(err, cert.ErrExpired) {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
break
}
}
rawCert, err := os.ReadFile(*vf.certPath)
if err != nil {
return fmt.Errorf("unable to read crt: %w", err)

View File

@@ -64,7 +64,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
require.ErrorIs(t, err, cert.ErrInvalidPEMBlock)
// make a ca for later
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)

View File

@@ -3,8 +3,15 @@
package main
import "github.com/sirupsen/logrus"
import (
"log/slog"
"os"
func HookLogger(l *logrus.Logger) {
// Do nothing, let the logs flow to stdout/stderr
"github.com/slackhq/nebula/logging"
)
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
// platforms have no special sink to integrate with.
// NOTE(review): level/format appear to be applied later via
// logging.ApplyConfig once configuration is loaded — confirm against callers.
func newPlatformLogger() *slog.Logger {
	return logging.NewLogger(os.Stdout)
}

View File

@@ -1,54 +1,86 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"context"
"log/slog"
"strings"
"sync"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/logging"
)
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
// logrus output will be discarded
func HookLogger(l *logrus.Logger) {
l.AddHook(newLogHook(logger))
l.SetOutput(ioutil.Discard)
// newPlatformLogger returns a *slog.Logger that routes every log record
// through the Windows service logger so records end up in the Windows
// Event Log. All the heavy lifting (level management, format swap,
// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
// this file only contributes:
//
// - an io.Writer that forwards each formatted line to the service
// logger at the current record's Event Log severity, and
// - a thin severityTag that embeds *logging.Handler and overrides
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
// column and severity-based filters keep working the way they did
// before the slog migration.
//
// Format (text vs json) is carried by the embedded *logging.Handler, so
// logging.format: json in config still produces JSON lines in Event
// Viewer, same as the pre-slog logrus setup.
func newPlatformLogger() *slog.Logger {
w := &eventLogWriter{}
return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w})
}
type logHook struct {
sl service.Logger
// eventLogWriter forwards slog-formatted lines to the Windows service
// logger at the severity most recently stashed by severityTag.Handle.
// The mutex serializes the stash + inner.Handle + Write cycle per record
// across all concurrent goroutines; slog's builtin text/json handlers
// each hold their own mutex around Write, but that only protects the
// Write call itself, not our stash-then-handle sequence.
type eventLogWriter struct {
	mu    sync.Mutex // serializes the per-record stash + format + Write cycle
	level slog.Level // severity of the record currently being formatted
}
func newLogHook(sl service.Logger) *logHook {
return &logHook{sl: sl}
}
func (h *logHook) Fire(entry *logrus.Entry) error {
line, err := entry.String()
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
return err
}
switch entry.Level {
case logrus.PanicLevel:
return h.sl.Error(line)
case logrus.FatalLevel:
return h.sl.Error(line)
case logrus.ErrorLevel:
return h.sl.Error(line)
case logrus.WarnLevel:
return h.sl.Warning(line)
case logrus.InfoLevel:
return h.sl.Info(line)
case logrus.DebugLevel:
return h.sl.Info(line)
func (w *eventLogWriter) Write(p []byte) (int, error) {
line := strings.TrimRight(string(p), "\n")
switch {
case w.level >= slog.LevelError:
return len(p), logger.Error(line)
case w.level >= slog.LevelWarn:
return len(p), logger.Warning(line)
default:
return nil
return len(p), logger.Info(line)
}
}
func (h *logHook) Levels() []logrus.Level {
return logrus.AllLevels
// severityTag embeds *logging.Handler to pick up everything it does for
// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
// so each record's slog.Level is stashed on the writer before formatting
// and so derived handlers stay wrapped as severityTag rather than
// downgrading to bare *logging.Handler.
type severityTag struct {
	*logging.Handler                 // supplies formatting, level, and config plumbing
	w                *eventLogWriter // shared writer carrying the stashed severity
}

// Handle stashes the record's level on the shared writer, then lets the
// embedded handler format and emit it. The writer's mutex is held for
// the whole stash+format sequence so concurrent records cannot pair one
// record's level with another record's output.
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
	s.w.mu.Lock()
	defer s.w.mu.Unlock()
	s.w.level = r.Level
	return s.Handler.Handle(ctx, r)
}

// WithAttrs re-wraps the derived handler as a severityTag so severity
// stashing keeps working; an empty attr list returns the receiver.
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return s
	}
	return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w}
}

// WithGroup mirrors WithAttrs: derived handlers stay wrapped so Event
// Log severity survives grouping; an empty name is a no-op.
func (s *severityTag) WithGroup(name string) slog.Handler {
	if name == "" {
		return s
	}
	return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w}
}

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util"
)
@@ -50,10 +50,15 @@ func main() {
os.Exit(0)
}
l := logging.NewLogger(os.Stdout)
if *serviceFlag != "" {
doService(configPath, configTest, Build, serviceFlag)
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
l.Error("Service command failed", "error", err)
os.Exit(1)
}
return
}
if *configPath == "" {
fmt.Println("-config flag must be set")
@@ -61,9 +66,6 @@ func main() {
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c := config.NewC(l)
err := c.Load(*configPath)
if err != nil {
@@ -71,6 +73,16 @@ func main() {
os.Exit(1)
}
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -78,8 +90,20 @@ func main() {
}
if !*configTest {
ctrl.Start()
ctrl.ShutdownBlock()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
if err := wait(); err != nil {
l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2)
}
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -7,9 +7,9 @@ import (
"path/filepath"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
)
var logger service.Logger
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
// Start should not block.
logger.Info("Nebula service starting.")
l := logrus.New()
HookLogger(l)
l := newPlatformLogger()
c := config.NewC(l)
err := c.Load(*p.configPath)
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
return fmt.Errorf("failed to load config: %s", err)
}
if err := logging.ApplyConfig(l, c); err != nil {
return fmt.Errorf("failed to apply logging config: %s", err)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
if err != nil {
return err
@@ -57,11 +65,11 @@ func fileExists(filename string) bool {
return true
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
if *configPath == "" {
ex, err := os.Executable()
if err != nil {
panic(err)
return err
}
*configPath = filepath.Dir(ex) + "/config.yaml"
if !fileExists(*configPath) {
@@ -85,16 +93,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
// Here are what the different loggers are doing:
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
// - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
s, err := service.New(prg, svcConfig)
if err != nil {
log.Fatal(err)
return err
}
errs := make(chan error, 5)
logger, err = s.Logger(errs)
if err != nil {
log.Fatal(err)
return err
}
go func() {
@@ -109,18 +117,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
switch *serviceFlag {
case "run":
err = s.Run()
if err != nil {
if err := s.Run(); err != nil {
// Route any errors to the system logger
logger.Error(err)
}
default:
err := service.Control(s, *serviceFlag)
if err != nil {
if err := service.Control(s, *serviceFlag); err != nil {
log.Printf("Valid actions: %q\n", service.ControlAction)
log.Fatal(err)
return err
}
return
}
return nil
}

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util"
)
@@ -55,8 +55,7 @@ func main() {
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
l := logging.NewLogger(os.Stdout)
c := config.NewC(l)
err := c.Load(*configPath)
@@ -65,6 +64,16 @@ func main() {
os.Exit(1)
}
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -72,9 +81,21 @@ func main() {
}
if !*configTest {
ctrl.Start()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l)
ctrl.ShutdownBlock()
if err := wait(); err != nil {
l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2)
}
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -1,11 +1,10 @@
package main
import (
"log/slog"
"net"
"os"
"time"
"github.com/sirupsen/logrus"
)
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
@@ -13,30 +12,30 @@ import (
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
const SdNotifyReady = "READY=1"
func notifyReady(l *logrus.Logger) {
func notifyReady(l *slog.Logger) {
sockName := os.Getenv("NOTIFY_SOCKET")
if sockName == "" {
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
return
}
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
if err != nil {
l.WithError(err).Error("failed to connect to systemd notification socket")
l.Error("failed to connect to systemd notification socket", "error", err)
return
}
defer conn.Close()
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
if err != nil {
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
return
}
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
l.WithError(err).Error("failed to signal the systemd notification socket")
l.Error("failed to signal the systemd notification socket", "error", err)
return
}
l.Debugln("notified systemd the service is ready")
l.Debug("notified systemd the service is ready")
}

View File

@@ -3,8 +3,8 @@
package main
import "github.com/sirupsen/logrus"
import "log/slog"
func notifyReady(_ *logrus.Logger) {
func notifyReady(_ *slog.Logger) {
// No init service to notify
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"os"
"os/signal"
@@ -16,7 +17,6 @@ import (
"time"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3"
)
@@ -26,11 +26,11 @@ type C struct {
Settings map[string]any
oldSettings map[string]any
callbacks []func(*C)
l *logrus.Logger
l *slog.Logger
reloadLock sync.Mutex
}
func NewC(l *logrus.Logger) *C {
func NewC(l *slog.Logger) *C {
return &C{
Settings: make(map[string]any),
l: l,
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
newVals, err := yaml.Marshal(nv)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
c.l.Error("Error while marshaling new config",
"config_path", k,
"error", err,
)
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
c.l.Error("Error while marshaling old config",
"config_path", k,
"error", err,
)
}
return string(newVals) != string(oldVals)
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
err := c.Load(c.path)
if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
c.l.Error("Error occurred while reloading config",
"config_path", c.path,
"error", err,
)
return
}

View File

@@ -5,13 +5,13 @@ import (
"context"
"encoding/binary"
"fmt"
"log/slog"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
@@ -47,10 +47,10 @@ type connectionManager struct {
metricsTxPunchy metrics.Counter
l *logrus.Logger
l *slog.Logger
}
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{
hostMap: hm,
l: l,
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
cm.l.Info("Inactivity timeout has changed",
"oldDuration", old,
"newDuration", cm.getInactivityTimeout(),
)
}
}
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
cm.l.Info("Drop inactive setting has changed",
"oldBool", old,
"newBool", cm.dropInactive.Load(),
)
}
}
}
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
var err error
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
continue
}
switch r.Type {
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
msg, err := req.Marshal()
if err != nil {
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
} else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
cm.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnAddrs": newhostinfo.vpnAddrs}).
Info("send CreateRelayRequest")
cm.l.Info("send CreateRelayRequest",
"relayFrom", req.RelayFromAddr,
"relayTo", req.RelayToAddr,
"initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex", req.ResponderRelayIndex,
"vpnAddrs", newhostinfo.vpnAddrs,
)
}
}
}
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
hostinfo := cm.hostMap.Indexes[localIndex]
if hostinfo == nil {
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
return doNothing, nil, nil
}
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
"tunnelCheck", m{"state": "alive", "method": "passive"},
)
}
hostinfo.pendingDeletion.Store(false)
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if hostinfo.pendingDeletion.Load() {
// We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
hostinfo.logger(cm.l).Info("Tunnel status",
"tunnelCheck", m{"state": "dead", "method": "active"},
)
return deleteTunnel, hostinfo, nil
}
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
if isInactive {
// Tunnel is inactive, tear it down
hostinfo.logger(cm.l).
WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
"inactiveDuration", inactiveFor,
"primary", mainHostInfo,
)
return closeTunnel, hostinfo, primary
}
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
cm.sendPunch(hostinfo)
}
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
"tunnelCheck", m{"state": "testing", "method": "active"},
)
}
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
decision = sendTestPacket
} else {
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Hostinfo sadness")
}
}
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
return false //cert is still valid! yay!
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
// Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is blocked, tearing down the tunnel")
hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
"error", err,
"fingerprint", remoteCert.Fingerprint,
)
return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
"error", err,
"fingerprint", remoteCert.Fingerprint,
)
return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"version", curCrtVersion,
"reason", "local certificate removed",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"version", curCrtVersion,
"peerVersion", peerCrt.Certificate.Version(),
"reason", "local certificate version lower than peer, attempting to correct",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"reason", "local certificate is not current",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"reason", "current cert version < pki.initiatingVersion",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return

View File

@@ -10,6 +10,7 @@ import (
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
nb := make([]byte, 12, 12)
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
nb := make([]byte, 12, 12)
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
ifce.connectionManager = nc

View File

@@ -8,7 +8,6 @@ import (
"sync/atomic"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil"
)
@@ -27,7 +26,7 @@ type ConnectionState struct {
writeLock sync.Mutex
}
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:

View File

@@ -2,17 +2,33 @@ package nebula
import (
"context"
"errors"
"log/slog"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
StateUnknown RunState = iota
StateReady
StateStarted
StateStopping
StateStopped
)
var ErrAlreadyStarted = errors.New("nebula is already started")
var ErrAlreadyStopped = errors.New("nebula cannot be restarted")
var ErrUnknownState = errors.New("nebula state is invalid")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,8 +42,11 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
l *slog.Logger
ctx context.Context
cancel context.CancelFunc
sshStart func()
@@ -49,10 +68,31 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Start actually runs nebula, this is a nonblocking call.
// The returned function blocks until nebula has fully stopped and returns the
// first fatal reader error (if any). A nil error means nebula shut down
// gracefully; a non-nil error means a reader hit an unexpected failure that
// triggered the shutdown.
func (c *Control) Start() (func() error, error) {
c.stateLock.Lock()
defer c.stateLock.Unlock()
switch c.state {
case StateReady:
//yay!
case StateStopped, StateStopping:
return nil, ErrAlreadyStopped
case StateStarted:
return nil, ErrAlreadyStarted
default:
return nil, ErrUnknownState
}
// Activate the interface
c.f.activate()
err := c.f.activate()
if err != nil {
c.state = StateStopped
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -71,25 +111,51 @@ func (c *Control) Start() {
c.lighthouseStart()
}
c.f.triggerShutdown = c.Stop
// Start reading packets.
c.f.run()
out, err := c.f.run()
if err != nil {
c.state = StateStopped
return nil, err
}
c.state = StateStarted
return out, nil
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != StateStarted {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = StateStopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
c.CloseAllTunnels(false)
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
c.l.Error("Close interface failed", "error", err)
}
c.l.Info("Goodbye")
c.stateLock.Lock()
c.state = StateStopped
c.stateLock.Unlock()
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
@@ -100,7 +166,7 @@ func (c *Control) ShutdownBlock() {
rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.l.Info("Caught signal, shutting down", "signal", sig)
c.Stop()
}
@@ -237,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h)
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
c.l.Debug("Sending close tunnel message",
"vpnAddrs", h.vpnAddrs,
"udpAddr", h.remote,
)
closed++
}

View File

@@ -6,7 +6,6 @@ import (
"reflect"
"testing"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
@@ -79,10 +78,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, &Interface{})
c := Control{
state: StateReady,
f: &Interface{
hostMap: hm,
},
l: logrus.New(),
l: test.NewLogger(),
}
thi := c.GetHostInfoByVpnAddr(vpnIp, false)

View File

@@ -1,5 +1,4 @@
//go:build e2e_testing
// +build e2e_testing
package nebula

View File

@@ -84,30 +84,24 @@ end
function nebula.prefs_changed()
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
-- Nothing changed, bail
return
end
-- Remove our old dissector
-- Remove all existing registrations
DissectorTable.get("udp.port"):remove_all(nebula)
if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then
default_settings.all_port = nebula.prefs.all_ports
if nebula.prefs.all_ports then
-- Register on every port for hole punch capture
for i=0, 65535 do
DissectorTable.get("udp.port"):add(i, nebula)
end
-- no need to establish again on specific ports
return
else
-- Register on the configured port only
DissectorTable.get("udp.port"):add(nebula.prefs.port, nebula)
end
if default_settings.all_ports ~= nebula.prefs.all_ports then
-- Add our new port dissector
default_settings.all_ports = nebula.prefs.all_ports
default_settings.port = nebula.prefs.port
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
end
end
DissectorTable.get("udp.port"):add(default_settings.port, nebula)

View File

@@ -1,63 +1,249 @@
package nebula
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// This whole thing should be rewritten to use context
var dnsR *dnsRecords
var dnsServer *dns.Server
var dnsAddr string
type dnsRecords struct {
type dnsServer struct {
sync.RWMutex
l *logrus.Logger
l *slog.Logger
ctx context.Context
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
mux *dns.ServeMux
// enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`.
// Start, Add, and reload consult it so callers don't need to know the
// gating rules. When it toggles off via reload, accumulated records are
// cleared so a later re-enable starts with a fresh map populated from
// new handshakes.
enabled atomic.Bool
serverMu sync.Mutex
server *dns.Server
// started is closed once `server` has finished binding (or after
// ListenAndServe returns on a bind failure). Stop waits on it before
// calling Shutdown to avoid the miekg/dns "server not started" race
// where a Shutdown that arrives before bind completes is silently
// ignored, leaving the listener running forever.
started chan struct{}
addr string
}
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{
// newDnsServerFromConfig builds a dnsServer, applies the initial config, and
// registers a reload callback. The reload callback is registered before the
// initial config is applied, so a SIGHUP can later enable, fix, or disable
// DNS even if the initial application failed.
//
// The dnsServer internally gates on `lighthouse.serve_dns &&
// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally,
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
// watcher that tears the listener down on nebula shutdown. The returned
// pointer is always non-nil, even on error.
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
	ds := &dnsServer{
		l:               l,
		ctx:             ctx,
		dnsMap4:         map[string]netip.Addr{},
		dnsMap6:         map[string]netip.Addr{},
		hostMap:         hostMap,
		myVpnAddrsTable: cs.myVpnAddrsTable,
	}

	// Wire up the request handler before any listener can exist.
	mux := dns.NewServeMux()
	mux.HandleFunc(".", ds.handleDnsRequest)
	ds.mux = mux

	// Register for reloads before the initial apply so a later SIGHUP can
	// still enable, fix, or disable DNS even if the first apply fails.
	c.RegisterReloadCallback(func(rc *config.C) {
		if rerr := ds.reload(rc, false); rerr != nil {
			ds.l.Error("Failed to reload DNS responder from config", "error", rerr)
		}
	})

	// The pointer is returned even on error so callers always get a usable
	// handle for the registered reload callback.
	err := ds.reload(c, true)
	return ds, err
}
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
// reload applies the latest config and reconciles the running state with it:
// - enabled toggled on -> spawn a runner
// - enabled toggled off -> stop the runner
// - listen address changed (while running) -> restart on the new address
// - everything else -> no-op
//
// On the initial call it only records configuration; Control.Start is what
// launches the first runner via dnsStart.
func (d *dnsServer) reload(c *config.C, initial bool) error {
	// DNS only runs when this host both wants it and is a lighthouse.
	wantsDns := c.GetBool("lighthouse.serve_dns", false)
	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
	enabled := wantsDns && amLighthouse
	newAddr := getDnsServerAddr(c)

	// Snapshot the running server and swap in the new config under the lock,
	// then reconcile outside it — Stop and shutdownServer take serverMu /
	// block on the started chan and must not run while we hold the lock.
	d.serverMu.Lock()
	running := d.server
	runningStarted := d.started
	sameAddr := d.addr == newAddr
	d.addr = newAddr
	d.enabled.Store(enabled)
	d.serverMu.Unlock()

	if initial {
		// The first application only records configuration; Control.Start
		// launches the first runner via dnsStart.
		if wantsDns && !amLighthouse {
			d.l.Warn("DNS server refusing to run because this host is not a lighthouse.")
		}
		return nil
	}

	if !enabled {
		if running != nil {
			d.Stop()
		}
		// Drop any records that accumulated while enabled; a later re-enable
		// will repopulate from fresh handshakes.
		d.clearRecords()
		return nil
	}

	if running == nil {
		// Was disabled (or never started); bring it up now.
		go d.Start()
		return nil
	}

	if sameAddr {
		// Enabled, already running, address unchanged: nothing to do.
		return nil
	}

	d.shutdownServer(running, runningStarted, "reload")
	// Old Start goroutine has now exited; bring up a fresh listener on the
	// new address.
	go d.Start()
	return nil
}
// shutdownServer waits for the server to finish binding (so Shutdown actually
// stops it rather than no-oping) and then shuts it down.
func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) {
	if srv == nil {
		// Nothing to tear down.
		return
	}
	if started != nil {
		// Block until the bind has completed (or been abandoned); a Shutdown
		// issued before the server has started is silently ignored.
		<-started
	}
	err := srv.Shutdown()
	if err != nil {
		d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
	}
}
// Start binds and serves the DNS responder. Blocks until Stop is called or
// the listener errors. Safe to call when DNS is disabled (returns
// immediately). This is what Control.dnsStart points at.
//
// Must be invoked after the tun device is active so that lighthouse.dns.host
// may bind to a nebula IP.
func (d *dnsServer) Start() {
	// Gated on the `enabled` flag that reload mirrors from
	// `lighthouse.serve_dns && lighthouse.am_lighthouse`.
	if !d.enabled.Load() {
		return
	}

	// Closed by NotifyStartedFunc once the listener has bound, or below when
	// the bind fails; Stop/shutdownServer wait on it before calling Shutdown.
	started := make(chan struct{})
	d.serverMu.Lock()
	if d.ctx.Err() != nil {
		// Nebula is already shutting down; don't bring up a new listener.
		d.serverMu.Unlock()
		return
	}
	addr := d.addr
	server := &dns.Server{
		Addr:              addr,
		Net:               "udp",
		Handler:           d.mux,
		NotifyStartedFunc: func() { close(started) },
	}
	// Publish this exact server/started pair under the lock so Stop and
	// reload can later detach and shut down this instance.
	d.server = server
	d.started = started
	d.serverMu.Unlock()

	// Per-invocation ctx watcher. Exits when Start does, so we don't leak a
	// watcher per reload-driven restart.
	done := make(chan struct{})
	go func() {
		select {
		case <-d.ctx.Done():
			d.shutdownServer(server, started, "shutdown")
		case <-done:
		}
	}()

	d.l.Info("Starting DNS responder", "dnsListener", addr)
	err := server.ListenAndServe()
	close(done)

	// If the listener never bound (bind error) NotifyStartedFunc never fires,
	// so close started here to release any Stop caller waiting on it.
	select {
	case <-started:
	default:
		close(started)
	}

	if err != nil {
		d.l.Warn("Failed to run the DNS responder", "error", err)
	}
}
// Stop shuts down the active server, if any. Idempotent.
func (d *dnsServer) Stop() {
	// Detach the current server under the lock so a concurrent Stop (or a
	// reload) observes nil and becomes a no-op.
	d.serverMu.Lock()
	srv, started := d.server, d.started
	d.server, d.started = nil, nil
	d.serverMu.Unlock()

	d.shutdownServer(srv, started, "stop")
}
// Query returns the address for the given name and query type. The second
// return value reports whether the name is known at all (in either A or AAAA),
// which lets callers distinguish NODATA from NXDOMAIN.
func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
addr4, haveV4 := d.dnsMap4[data]
addr6, haveV6 := d.dnsMap6[data]
nameExists := haveV4 || haveV6
switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
if haveV4 {
return addr4, nameExists
}
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
if haveV6 {
return addr6, nameExists
}
}
return netip.Addr{}
return netip.Addr{}, nameExists
}
func (d *dnsRecords) QueryCert(data string) string {
func (d *dnsServer) QueryCert(data string) string {
if len(data) < 2 {
return ""
}
ip, err := netip.ParseAddr(data[:len(data)-1])
if err != nil {
return ""
@@ -80,8 +266,19 @@ func (d *dnsRecords) QueryCert(data string) string {
return string(b)
}
// clearRecords drops all DNS records.
func (d *dnsServer) clearRecords() {
	d.Lock()
	defer d.Unlock()
	// Wipe both address families in place; the maps keep their buckets for
	// cheap refilling if DNS is re-enabled later.
	for _, m := range []map[string]netip.Addr{d.dnsMap4, d.dnsMap6} {
		clear(m)
	}
}
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
func (d *dnsServer) Add(host string, addresses []netip.Addr) {
if !d.enabled.Load() {
return
}
host = strings.ToLower(host)
d.Lock()
defer d.Unlock()
@@ -101,7 +298,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
}
}
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
@@ -116,13 +313,24 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
return d.myVpnAddrsTable.Contains(b)
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
// type must be answered with NOERROR and an empty answer section (NODATA),
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
// exist at all.
anyNameExists := false
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name)
if debugEnabled {
d.l.Debug("DNS query", "type", qType, "name", q.Name)
}
ip, nameExists := d.Query(q.Qtype, q.Name)
if nameExists {
anyNameExists = true
}
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
@@ -134,7 +342,9 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}
d.l.Debugf("Query for TXT %s", q.Name)
if debugEnabled {
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
}
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
@@ -145,12 +355,12 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
}
}
if len(m.Answer) == 0 {
if len(m.Answer) == 0 && !anyNameExists {
m.Rcode = dns.RcodeNameError
}
}
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
@@ -163,21 +373,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func
dns.HandleFunc(".", dnsR.handleDnsRequest)
c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c)
})
return func() {
startDns(l, c)
}
}
func getDnsServerAddr(c *config.C) string {
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
@@ -186,25 +381,3 @@ func getDnsServerAddr(c *config.C) string {
}
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
}
func startDns(l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
err := dnsServer.ListenAndServe()
defer dnsServer.Shutdown()
if err != nil {
l.Errorf("Failed to start server: %s\n ", err.Error())
}
}
func reloadDns(l *logrus.Logger, c *config.C) {
if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected")
return
}
l.Debug("Restarting DNS server")
dnsServer.Shutdown()
go startDns(l, c)
}

View File

@@ -1,19 +1,43 @@
package nebula
import (
"context"
"log/slog"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubDNSWriter is a minimal dns.ResponseWriter for driving parseQuery in
// tests without a real UDP connection. All writes are discarded.
type stubDNSWriter struct{}

func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} }

// RemoteAddr reports 127.0.0.1 so the responder treats the query as coming
// from localhost (the TXT branch of parseQuery is gated on
// isSelfNebulaOrLocalhost).
func (stubDNSWriter) RemoteAddr() net.Addr {
	return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353}
}
func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil }
func (stubDNSWriter) WriteMsg(*dns.Msg) error   { return nil }
func (stubDNSWriter) Close() error              { return nil }
func (stubDNSWriter) TsigStatus() error         { return nil }
func (stubDNSWriter) TsigTimersOnly(bool)       {}
func (stubDNSWriter) Hijack()                   {}
func TestParsequery(t *testing.T) {
l := logrus.New()
l := slog.New(slog.DiscardHandler)
hostMap := &HostMap{}
ds := newDnsRecords(l, &CertState{}, hostMap)
ds := &dnsServer{
l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
}
ds.enabled.Store(true)
addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
@@ -21,18 +45,56 @@ func TestParsequery(t *testing.T) {
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")})
ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")})
m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
// A known name with no record of the requested type should return NODATA
// (NOERROR with empty answer), not NXDOMAIN.
m = &dns.Msg{}
m.SetQuestion("v4only.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
m = &dns.Msg{}
m.SetQuestion("v6only.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
// An unknown name should still return NXDOMAIN.
m = &dns.Msg{}
m.SetQuestion("unknown.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
// short lookups should not fail
m = &dns.Msg{}
m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
ds.parseQuery(m, stubDNSWriter{})
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
m = &dns.Msg{}
m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
ds.parseQuery(m, stubDNSWriter{})
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
}
func Test_getDnsServerAddr(t *testing.T) {
@@ -71,3 +133,208 @@ func Test_getDnsServerAddr(t *testing.T) {
}
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
}
// newTestDnsServer builds a dnsServer with a discarding logger and an empty
// host map, plus a fresh config for the test to populate.
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
	t.Helper()

	ds := &dnsServer{
		l:       slog.New(slog.DiscardHandler),
		ctx:     context.Background(),
		dnsMap4: map[string]netip.Addr{},
		dnsMap6: map[string]netip.Addr{},
		hostMap: &HostMap{},
	}

	mux := dns.NewServeMux()
	mux.HandleFunc(".", ds.handleDnsRequest)
	ds.mux = mux

	return ds, config.NewC(nil)
}
// setDnsConfig injects a lighthouse DNS section directly into the config's
// settings map.
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
	dnsSection := map[string]any{
		"host": host,
		"port": port,
	}
	c.Settings["lighthouse"] = map[string]any{
		"am_lighthouse": amLighthouse,
		"serve_dns":     serveDns,
		"dns":           dnsSection,
	}
}
func TestDnsServer_reload_initial_disabled(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, false)

	require.NoError(t, ds.reload(c, true))

	// serve_dns is off: the responder records the address but stays gated
	// off and spawns no server.
	assert.Nil(t, ds.server)
	assert.Equal(t, "127.0.0.1:0", ds.addr)
	assert.False(t, ds.enabled.Load())
}
func TestDnsServer_reload_initial_enabled(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, true)

	require.NoError(t, ds.reload(c, true))

	// The initial reload only records configuration; launching the first
	// runner is Control.Start's job, so no server exists yet.
	assert.Nil(t, ds.server)
	assert.Equal(t, "127.0.0.1:0", ds.addr)
	assert.True(t, ds.enabled.Load())
}
func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", false, true)

	require.NoError(t, ds.reload(c, true))

	// serve_dns without am_lighthouse is gated off: no runner is spawned.
	assert.False(t, ds.enabled.Load())
}
func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, true)
	require.NoError(t, ds.reload(c, true))

	// A second reload with an unchanged address and no running server must
	// not spawn anything.
	require.NoError(t, ds.reload(c, false))
	assert.Nil(t, ds.server)
	assert.True(t, ds.enabled.Load())
}
func TestDnsServer_StartStop_lifecycle(t *testing.T) {
	// Bind to a real (random) UDP port so we exercise the actual
	// ListenAndServe + Shutdown plumbing including the started-chan race fix.
	port := freeUDPPort(t)
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", port, true, true)
	require.NoError(t, ds.reload(c, true))

	done := make(chan struct{})
	go func() {
		ds.Start()
		close(done)
	}()

	// Wait until the listener has actually bound before stopping it; this
	// was previously an inline copy of waitForBind's polling loop.
	waitForBind(t, ds)

	ds.Stop()
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("Start did not return after Stop")
	}
}
func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
	// Stop called immediately after Start should not deadlock even if bind
	// hasn't completed yet. This exercises the started-chan close-on-bind-fail
	// path: an address that cannot be bound makes ListenAndServe fail fast,
	// before NotifyStartedFunc ever fires.
	ds, c := newTestDnsServer(t)
	// "256.256.256.256" is not a valid IPv4 address, so the bind fails
	// immediately rather than opening a real listener.
	setDnsConfig(c, "256.256.256.256", "53", true, true)
	require.NoError(t, ds.reload(c, true))

	done := make(chan struct{})
	go func() {
		ds.Start()
		close(done)
	}()

	// Give Start a moment to attempt the bind and fail.
	select {
	case <-done:
		// Bind failed and Start returned; Stop should be a no-op.
	case <-time.After(time.Second):
		t.Fatal("Start did not return after a bad bind")
	}

	stopped := make(chan struct{})
	go func() {
		ds.Stop()
		close(stopped)
	}()

	select {
	case <-stopped:
	case <-time.After(time.Second):
		t.Fatal("Stop hung after a failed bind")
	}
}
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
	port := freeUDPPort(t)
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", port, true, true)
	require.NoError(t, ds.reload(c, true))

	// Run the responder in the background and wait for the bind.
	served := make(chan struct{})
	go func() {
		defer close(served)
		ds.Start()
	}()
	waitForBind(t, ds)

	// Flip serve_dns off; reload should shut the running server down, which
	// in turn lets the Start goroutine return.
	setDnsConfig(c, "127.0.0.1", port, true, false)
	require.NoError(t, ds.reload(c, false))

	select {
	case <-served:
	case <-time.After(5 * time.Second):
		t.Fatal("Start did not return after reload disabled DNS")
	}
	assert.False(t, ds.enabled.Load())
}
// freeUDPPort asks the OS for an unused loopback UDP port and returns it as
// a string.
func freeUDPPort(t *testing.T) string {
	t.Helper()
	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	p := pc.LocalAddr().(*net.UDPAddr).Port
	require.NoError(t, pc.Close())
	return strconv.Itoa(p)
}
// waitForBind polls until the server's started channel exists and has been
// closed, i.e. the listener has finished binding.
func waitForBind(t *testing.T, ds *dnsServer) {
	t.Helper()
	bound := func() bool {
		ds.serverMu.Lock()
		ch := ds.started
		ds.serverMu.Unlock()
		if ch == nil {
			return false
		}
		select {
		case <-ch:
			return true
		default:
			return false
		}
	}
	waitFor(t, bound)
}
// waitFor polls cond every few milliseconds and fails the test if it does not
// become true within five seconds.
func waitFor(t *testing.T, cond func() bool) {
	t.Helper()
	const (
		timeout = 5 * time.Second
		step    = 5 * time.Millisecond
	)
	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(step) {
		if cond() {
			return
		}
	}
	t.Fatal("timed out waiting for condition")
}

View File

@@ -0,0 +1,565 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
// makeHandshakePacket builds a synthetic handshake packet addressed from
// `from` to `to`. The bytes beyond the encoded header are filled with a
// deterministic pattern (each byte equals its own offset).
func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, remoteIndex uint32, counter uint64) *udp.Packet {
	buf := make([]byte, 200)
	header.Encode(buf, header.Version, header.Handshake, subtype, remoteIndex, counter)
	for off := range buf[header.Len:] {
		buf[header.Len+off] = byte(header.Len + off)
	}
	return &udp.Packet{To: to, From: from, Data: buf}
}
// TestHandshakeRetransmitDuplicate exercises the responder's handling of a
// retransmitted msg1 and confirms the cached response completes the tunnel.
func TestHandshakeRetransmitDuplicate(t *testing.T) {
	// Verify the responder correctly handles receiving the same msg1 multiple times
	// (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen
	// and the cached response is resent.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	// Teach each side the other's underlay address.
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Trigger handshake from me to them")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
	t.Log("Grab my msg1")
	msg1 := myControl.GetFromUDP(true)
	t.Log("Inject msg1 into them, first time")
	theirControl.InjectUDPPacket(msg1)
	// Discard the first response; we only care about the reply to the duplicate.
	_ = theirControl.GetFromUDP(true)
	t.Log("Inject the SAME msg1 again, tests ErrAlreadySeen path")
	theirControl.InjectUDPPacket(msg1)
	resp2 := theirControl.GetFromUDP(true)
	assert.NotNil(t, resp2, "should get cached response on duplicate msg1")
	t.Log("Complete handshake with cached response")
	myControl.InjectUDPPacket(resp2)
	myControl.WaitForType(1, 0, theirControl)
	t.Log("Drain cached packet and verify tunnel works")
	// The "Hi" packet queued before the handshake should now be delivered.
	cachedPacket := theirControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Verify only one tunnel exists on each side")
	assert.Len(t, myControl.ListHostmapHosts(false), 1)
	assert.Len(t, theirControl.ListHostmapHosts(false), 1)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeTruncatedPacketRecovery confirms that a response truncated to
// just the header is ignored and the intact response still completes the
// handshake afterwards.
func TestHandshakeTruncatedPacketRecovery(t *testing.T) {
	// Verify that a truncated handshake packet is ignored and the real
	// packet can still complete the handshake.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Trigger handshake")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
	t.Log("Get msg1 and deliver to responder")
	msg1 := myControl.GetFromUDP(true)
	theirControl.InjectUDPPacket(msg1)
	t.Log("Get the real response")
	realResp := theirControl.GetFromUDP(true)
	t.Log("Truncate the response and inject, should be ignored")
	// Copy first so the original packet's data stays intact for later.
	truncResp := realResp.Copy()
	truncResp.Data = truncResp.Data[:header.Len]
	myControl.InjectUDPPacket(truncResp)
	t.Log("Verify pending handshake survived the truncated packet")
	assert.NotEmpty(t, myControl.ListHostmapHosts(true), "pending handshake should still exist")
	t.Log("Inject real response, should complete handshake")
	myControl.InjectUDPPacket(realResp)
	myControl.WaitForType(1, 0, theirControl)
	t.Log("Drain and verify tunnel")
	cachedPacket := theirControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeOrphanedMsg2Dropped confirms a msg2 referencing an unknown
// index is dropped without a reply and without disturbing an existing tunnel.
func TestHandshakeOrphanedMsg2Dropped(t *testing.T) {
	// A msg2 arriving with no matching pending index should be silently dropped
	// with no response sent and no state changes.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Complete a normal handshake")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
	r.RouteForAllUntilTxTun(theirControl)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Record hostmap state")
	myIndexes := len(myControl.ListHostmapIndexes(false))
	t.Log("Inject a fake msg2 with unknown RemoteIndex")
	// Counter 2 marks this as a stage-2 (response) handshake message.
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0xDEADBEEF, 2))
	t.Log("Verify no new indexes created")
	assert.Equal(t, myIndexes, len(myControl.ListHostmapIndexes(false)))
	t.Log("Verify no UDP response was sent")
	// Give the node a moment to (incorrectly) respond before the non-blocking read.
	time.Sleep(100 * time.Millisecond)
	assert.Nil(t, myControl.GetFromUDP(false), "should not send a response to orphaned msg2")
	t.Log("Verify existing tunnel still works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeUnknownMessageCounter confirms handshake packets with counters
// other than the expected stage values are dropped without side effects.
func TestHandshakeUnknownMessageCounter(t *testing.T) {
	// A handshake packet with an unexpected message counter should be silently
	// dropped with no side effects and no UDP response.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	myControl.Start()
	theirControl.Start()
	t.Log("Inject handshake with MessageCounter=3")
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 3))
	t.Log("Inject handshake with MessageCounter=99")
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 99))
	t.Log("Verify no tunnels or pending handshakes")
	// Check both the main (false) and pending (true) hostmaps.
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	t.Log("Verify no UDP response was sent")
	time.Sleep(100 * time.Millisecond)
	assert.Nil(t, myControl.GetFromUDP(false))
	myControl.Stop()
	theirControl.Stop()
}
// A handshake-typed packet whose subtype is not recognized must be discarded
// without creating any tunnel state or producing a reply.
func TestHandshakeUnknownSubtype(t *testing.T) {
	caCert, _, caPrivKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	me, _, meAddr, _ := newSimpleServer(cert.Version1, caCert, caPrivKey, "me", "10.128.0.1/24", nil)
	them, _, themAddr, _ := newSimpleServer(cert.Version1, caCert, caPrivKey, "them", "10.128.0.2/24", nil)
	me.Start()
	them.Start()

	t.Log("Inject handshake with unknown subtype 99")
	me.InjectUDPPacket(makeHandshakePacket(themAddr, meAddr, header.MessageSubType(99), 0, 1))

	t.Log("Verify no tunnels or pending handshakes")
	// Neither the main (false) nor the pending (true) hostmap may gain entries.
	assert.Empty(t, me.ListHostmapHosts(false))
	assert.Empty(t, me.ListHostmapHosts(true))

	t.Log("Verify no UDP response was sent")
	// Allow a moment for any (incorrect) reply before the non-blocking read.
	time.Sleep(100 * time.Millisecond)
	assert.Nil(t, me.GetFromUDP(false))

	me.Stop()
	them.Stop()
}
// TestHandshakeLateResponse confirms that a handshake response arriving after
// the initiator has timed out and discarded its pending state is ignored.
func TestHandshakeLateResponse(t *testing.T) {
	// After a handshake times out, a late response should be silently ignored
	// with no new tunnels created.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	// Short retry settings so the pending handshake expires quickly.
	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
		"handshakes": m{
			"try_interval": "200ms",
			"retries":      2,
		},
	})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	myControl.Start()
	theirControl.Start()
	t.Log("Trigger handshake from me")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
	t.Log("Grab msg1 but don't deliver")
	msg1 := myControl.GetFromUDP(true)
	t.Log("Wait for handshake to time out")
	// Drain any retransmits while the retries are exhausted.
	for i := 0; i < 5; i++ {
		time.Sleep(300 * time.Millisecond)
		myControl.GetFromUDP(false)
	}
	t.Log("Confirm no pending handshakes remain")
	assert.Empty(t, myControl.ListHostmapHosts(true))
	t.Log("Deliver old msg1 to them, they create a tunnel")
	theirControl.InjectUDPPacket(msg1)
	resp := theirControl.GetFromUDP(true)
	assert.NotNil(t, resp)
	t.Log("Inject late response into me, should be ignored")
	myControl.InjectUDPPacket(resp)
	t.Log("No tunnel should exist on my side")
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeSelfConnectionRejected feeds a node its own msg1 and confirms
// no stage-2 response is produced and no tunnel to itself is created.
func TestHandshakeSelfConnectionRejected(t *testing.T) {
	// Verify that a node rejects a handshake containing its own VPN IP in the
	// peer cert. We do this by sending the initiator's own msg1 back to itself.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	// Need a lighthouse entry to trigger a handshake
	myControl.InjectLightHouseAddr(netip.MustParseAddr("10.128.0.2"), netip.MustParseAddrPort("10.0.0.2:4242"))
	myControl.Start()
	t.Log("Trigger handshake from me")
	myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
	msg1 := myControl.GetFromUDP(true)
	t.Log("Drain any handshake retransmits before injecting")
	time.Sleep(100 * time.Millisecond)
	// Non-blocking reads until the outbound queue is empty.
	for myControl.GetFromUDP(false) != nil {
	}
	t.Log("Feed my own msg1 back to me as if it came from someone else")
	// Copy and rewrite the addresses so the packet appears to come from a
	// third party rather than being a loopback of our own send.
	selfMsg := msg1.Copy()
	selfMsg.From = netip.MustParseAddrPort("10.0.0.99:4242")
	selfMsg.To = myUdpAddr
	myControl.InjectUDPPacket(selfMsg)
	t.Log("Verify no response was sent (self-connection rejected)")
	time.Sleep(100 * time.Millisecond)
	// Drain any further retransmits from the original handshake, then check
	// that none of them are a handshake response (MessageCounter=2)
	h := &header.H{}
	for {
		p := myControl.GetFromUDP(false)
		if p == nil {
			break
		}
		_ = h.Parse(p.Data)
		assert.NotEqual(t, uint64(2), h.MessageCounter,
			"should not send a stage 2 response to self-connection")
	}
	t.Log("Verify no tunnel to myself was created")
	assert.Nil(t, myControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false))
	myControl.Stop()
}
// A handshake packet carrying MessageCounter=0 is not a valid stage and must
// be dropped with no tunnel state and no reply.
func TestHandshakeMessageCounter0Dropped(t *testing.T) {
	caCert, _, caPrivKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	me, _, meAddr, _ := newSimpleServer(cert.Version1, caCert, caPrivKey, "me", "10.128.0.1/24", nil)
	// The second node exists only to supply a plausible source address.
	_, _, themAddr, _ := newSimpleServer(cert.Version1, caCert, caPrivKey, "them", "10.128.0.2/24", nil)
	me.Start()

	t.Log("Inject handshake with MessageCounter=0")
	me.InjectUDPPacket(makeHandshakePacket(themAddr, meAddr, header.HandshakeIXPSK0, 0, 0))

	// Allow a moment for any (incorrect) reaction, then verify nothing changed.
	time.Sleep(100 * time.Millisecond)
	assert.Empty(t, me.ListHostmapHosts(false))
	assert.Empty(t, me.ListHostmapHosts(true))
	assert.Nil(t, me.GetFromUDP(false))

	me.Stop()
}
// TestHandshakeRemoteAllowList verifies remote_allow_list filtering: the same
// msg1 is rejected from a blocked underlay IP but accepted from an allowed one.
func TestHandshakeRemoteAllowList(t *testing.T) {
	// Verify that a handshake from a blocked underlay IP is dropped with no
	// response and no state changes. Then verify the same packet from an
	// allowed IP succeeds.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	// Only 10.0.0.0/8 is allowed; everything else is blocked.
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
		"lighthouse": m{
			"remote_allow_list": m{
				"10.0.0.0/8": true,
				"0.0.0.0/0":  false,
			},
		},
	})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Trigger handshake from them")
	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi"))
	msg1 := theirControl.GetFromUDP(true)
	t.Log("Rewrite the source to a blocked IP and inject")
	// 192.168.1.1 falls outside the allowed 10.0.0.0/8 range.
	blockedMsg := msg1.Copy()
	blockedMsg.From = netip.MustParseAddrPort("192.168.1.1:4242")
	myControl.InjectUDPPacket(blockedMsg)
	t.Log("Verify no tunnel, no pending, no response from blocked source")
	time.Sleep(100 * time.Millisecond)
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	assert.Nil(t, myControl.GetFromUDP(false), "should not respond to blocked source")
	t.Log("Now inject the real packet from the allowed source")
	myControl.InjectUDPPacket(msg1)
	t.Log("Verify handshake completes from allowed source")
	resp := myControl.GetFromUDP(true)
	assert.NotNil(t, resp)
	theirControl.InjectUDPPacket(resp)
	theirControl.WaitForType(1, 0, myControl)
	t.Log("Drain cached packet and verify tunnel works")
	cachedPacket := myControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi"), cachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeAlreadySeenPreferredRemote re-triggers a handshake over an
// established tunnel and checks the remote and index count stay stable.
func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
	// When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel
	// remains functional and hostmap index count is stable.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Complete a normal handshake via the router")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
	r.RouteForAllUntilTxTun(theirControl)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Record hostmap state")
	// Capture the baseline index count and remote so we can compare after.
	theirIndexes := len(theirControl.ListHostmapIndexes(false))
	hi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, hi)
	originalRemote := hi.CurrentRemote
	t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam"))
	r.RouteForAllUntilTxTun(theirControl)
	t.Log("Verify tunnel still works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Verify remote is still valid and index count is stable")
	hi2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, hi2)
	assert.Equal(t, originalRemote, hi2.CurrentRemote)
	assert.Equal(t, theirIndexes, len(theirControl.ListHostmapIndexes(false)),
		"no extra indexes should be created from ErrAlreadySeen")
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeWrongResponderPacketStore simulates a lighthouse pointing at
// the wrong host ("evil") and verifies recovery: the evil tunnel is torn down,
// its address blocked, and the originally-cached packets reach the real peer.
func TestHandshakeWrongResponderPacketStore(t *testing.T) {
	// Verify that when the wrong host responds, the cached packets are
	// transferred to the new handshake, the evil tunnel is closed, evil's
	// address is blocked, and the correct tunnel is eventually established.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
	evilControl, evilVpnIpNet, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
	// Deliberately point "them"'s vpn addr at evil's underlay address.
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
	r := router.NewR(t, myControl, theirControl, evilControl)
	defer r.RenderFlow()
	myControl.Start()
	theirControl.Start()
	evilControl.Start()
	t.Log("Send multiple packets to them (cached during handshake)")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1"))
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2"))
	t.Log("Route until evil tunnel is closed")
	// Watch the packet flow until we see a CloseTunnel headed to evil.
	h := &header.H{}
	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
		if err := h.Parse(p.Data); err != nil {
			panic(err)
		}
		if h.Type == header.CloseTunnel && p.To == evilUdpAddr {
			return router.RouteAndExit
		}
		return router.KeepRouting
	})
	t.Log("Verify evil's address is blocked in the new pending handshake")
	pendingHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
	if pendingHI != nil {
		assert.NotContains(t, pendingHI.RemoteAddrs, evilUdpAddr,
			"evil's address should be blocked")
	}
	t.Log("Inject correct lighthouse addr for them")
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	t.Log("Route until cached packets arrive at the real them")
	p := r.RouteForAllUntilTxTun(theirControl)
	assert.NotNil(t, p, "a cached packet should be delivered to the correct host")
	t.Log("Verify the correct host has a tunnel")
	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
	t.Log("Verify no hostinfo artifacts from evil remain")
	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), true),
		"no pending hostinfo for evil")
	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), false),
		"no main hostinfo for evil")
	myControl.Stop()
	theirControl.Stop()
	evilControl.Stop()
}
// TestHandshakeRelayComplete establishes a tunnel through a relay node and
// checks that relay state is recorded on all three participants.
func TestHandshakeRelayComplete(t *testing.T) {
	// Verify that a relay handshake completes correctly and relay state is
	// properly maintained on all three nodes.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
	// "me" only knows how to reach the relay, and that "them" is behind it.
	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	r := router.NewR(t, myControl, relayControl, theirControl)
	defer r.RenderFlow()
	myControl.Start()
	relayControl.Start()
	theirControl.Start()
	t.Log("Trigger handshake via relay")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay"))
	p := r.RouteForAllUntilTxTun(theirControl)
	assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
	t.Log("Verify bidirectional tunnel via relay")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Verify relay state on my side shows relay-to-me")
	myHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
	assert.NotNil(t, myHI)
	assert.NotEmpty(t, myHI.CurrentRelaysToMe, "should have relay-to-me for them")
	t.Log("Verify relay state on their side shows relay-to-me")
	theirHI := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, theirHI)
	assert.NotEmpty(t, theirHI.CurrentRelaysToMe, "should have relay-to-me for me")
	t.Log("Verify relay node shows through-me relays")
	relayHI := relayControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, relayHI)
	myControl.Stop()
	relayControl.Stop()
	theirControl.Stop()
}
// NOTE: Relay V1 cert + IPv6 rejection is not tested here because
// InjectTunUDPPacket from a V4 node to a V6 address panics in the test
// framework. The check is in handshake_manager.go handleOutbound relay
// logic (lines ~304-313): if the relay host has a V1 cert and either
// address is IPv6, the relay is skipped.
// NOTE: Relay reestablishment (Disestablished state transition) is covered
// by the existing TestReestablishRelays in handshakes_test.go.

View File

@@ -11,7 +11,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
@@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) {
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
@@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) {
theirControl.Start()
r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
r.Log("Wait for a packet from them to me")
l.Info("Wait for a packet from them to me; myControl")
r.Log("Wait for a packet from them to me; myControl")
r.RouteForAllUntilTxTun(myControl)
l.Info("Wait for a packet from them to me; theirControl")
r.Log("Wait for a packet from them to me; theirControl")
r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels")
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
len(myControl.GetHostmap().Indexes),
len(theirControl.GetHostmap().Indexes),
len(relayControl.GetHostmap().Indexes),
)
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
retries := 60
for hostInfos > 6 && retries > 0 {
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
len(myControl.GetHostmap().Indexes),
len(theirControl.GetHostmap().Indexes),
len(relayControl.GetHostmap().Indexes),
)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
@@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) {
}
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
myControl.Stop()
@@ -1369,6 +1358,81 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
theirControl.Stop()
}
// TestLighthouseUpdateOnReload verifies that adding a lighthouse via a config
// reload triggers an immediate host update to that lighthouse rather than
// waiting for the (long) update interval.
func TestLighthouseUpdateOnReload(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	// Create the lighthouse
	lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh", "10.128.0.1/24", m{"lighthouse": m{"am_lighthouse": true}})
	// Create a client with NO lighthouse configured and a long update interval.
	// The initial SendUpdate at startup will be a no-op since no lighthouses are known.
	myControl, myVpnIpNet, _, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.2/24", m{
		"lighthouse": m{
			"interval": 600,
			"local_allow_list": m{
				"10.0.0.0/24": true,
				"::/0":        false,
			},
		},
	})
	r := router.NewR(t, lhControl, myControl)
	defer r.RenderFlow()
	lhControl.Start()
	myControl.Start()
	// Drain any startup packets (there should be none meaningful)
	r.FlushAll()
	// Verify lighthouse has no knowledge of the client
	assert.Nil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
	// Build a new config that adds the lighthouse
	// Start from a copy of the current settings so everything else is preserved.
	newSettings := make(m)
	for k, v := range myConfig.Settings {
		newSettings[k] = v
	}
	newSettings["static_host_map"] = m{
		lhVpnIpNet[0].Addr().String(): []any{lhUdpAddr.String()},
	}
	newSettings["lighthouse"] = m{
		"hosts":    []any{lhVpnIpNet[0].Addr().String()},
		"interval": 600,
		"local_allow_list": m{
			"10.0.0.0/24": true,
			"::/0":        false,
		},
	}
	newCfg, err := yaml.Marshal(newSettings)
	require.NoError(t, err)
	// Reload the config. The lighthouse.hosts change triggers TriggerUpdate,
	// which wakes the update worker. It calls SendUpdate, initiating a
	// handshake to the new lighthouse and caching the HostUpdateNotification.
	require.NoError(t, myConfig.ReloadConfigString(string(newCfg)))
	// Route until the lighthouse receives the HostUpdateNotification.
	// This covers: handshake stage 1, stage 2, then the cached update.
	// Run the router in a goroutine so we can bound it with a timeout.
	done := make(chan struct{})
	go func() {
		r.RouteForAllUntilAfterMsgTypeTo(lhControl, header.LightHouse, 0)
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for lighthouse update after config reload")
	}
	// Verify lighthouse now has the client's addresses
	assert.NotNil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
	r.RenderHostmaps("Final hostmaps", lhControl, myControl)
	lhControl.Stop()
	myControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})

View File

@@ -4,7 +4,6 @@
package e2e
import (
"fmt"
"io"
"net/netip"
"os"
@@ -12,15 +11,18 @@ import (
"testing"
"time"
"log/slog"
"dario.cat/mergo"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
"level": l.Level.String(),
"level": testLogLevelName(),
},
"timers": m{
"pending_deletion_interval": 2,
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
"level": testLogLevelName(),
},
"timers": m{
"pending_deletion_interval": 2,
@@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
return a
}
func NewTestLogger() *logrus.Logger {
l := logrus.New()
func NewTestLogger() *slog.Logger {
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(io.Discard)
l.SetLevel(logrus.PanicLevel)
return l
return slog.New(slog.NewTextHandler(io.Discard, nil))
}
level := slog.LevelInfo
switch v {
case "2":
l.SetLevel(logrus.DebugLevel)
level = slog.LevelDebug
case "3":
l.SetLevel(logrus.TraceLevel)
default:
l.SetLevel(logrus.InfoLevel)
level = logging.LevelTrace
}
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
}
return l
// testLogLevelName returns the level name string accepted by logging.ApplyConfig
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
func testLogLevelName() string {
switch os.Getenv("TEST_LOGS") {
case "2":
return "debug"
case "3":
return "trace"
case "":
return "info"
}
return "info"
}

View File

@@ -12,6 +12,8 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
@@ -365,3 +367,106 @@ func TestCrossStackRelaysWork(t *testing.T) {
//theirControl.Stop()
//relayControl.Stop()
}
// TestCloseTunnelAuthenticated verifies two properties of tunnel teardown:
//  1. With tunnels.drop_inactive enabled, an idle tunnel is removed from both
//     hostmaps within the configured inactivity_timeout window.
//  2. A raw, unauthenticated CloseTunnel packet (hand-built header with a
//     bogus message counter) must be rejected and must NOT tear down an
//     established tunnel.
func TestCloseTunnelAuthenticated(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	// "me" uses a short 5s inactivity timeout so it initiates the drop;
	// "them" uses a long timeout and should only drop in response.
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})

	// Share our underlay information
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Start the servers
	myControl.Start()
	theirControl.Start()

	r := router.NewR(t, myControl, theirControl)

	r.Log("Assert the tunnel between me and them works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	r.Log("Close the tunnel")
	myControl.CloseTunnel(theirVpnIpNet[0].Addr(), false)
	r.FlushAll()

	// Happy path: poll until both hostmaps are empty. "me" is configured with
	// a 5s inactivity timeout, so both sides must be gone before the 6s
	// deadline below.
	waitStart := time.Now()
	for {
		myIndexes := len(myControl.GetHostmap().Indexes)
		theirIndexes := len(theirControl.GetHostmap().Indexes)
		if myIndexes == 0 && theirIndexes == 0 {
			break
		}

		since := time.Since(waitStart)
		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
		if since > time.Second*6 {
			t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 6 seconds")
		}

		time.Sleep(1 * time.Second)
		//r.FlushAll()
	}

	r.Logf("Happy path success, tunnels were dropped within %v", time.Since(waitStart))

	// Re-share underlay information and confirm a fresh tunnel can be built.
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	r.Log("Assert another tunnel between me and them works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	hi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
	if hi == nil {
		t.Fatal("There is no hostinfo for this tunnel")
	}
	myHi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	if myHi == nil {
		t.Fatal("There is no hostinfo for my tunnel")
	}
	r.Log("It does")

	// Hand-craft a CloseTunnel header aimed at the correct remote index but
	// carrying an arbitrary message counter. Since it is not produced by the
	// real session, the receiver must reject it.
	buf := make([]byte, 1024)
	hdr := header.H{
		Version:        1,
		Type:           header.CloseTunnel,
		Subtype:        0,
		Reserved:       0,
		RemoteIndex:    hi.RemoteIndex,
		MessageCounter: 5,
	}
	out, err := hdr.Encode(buf)
	if err != nil {
		t.Fatal(err)
	}

	pkt := &udp.Packet{
		To:   hi.CurrentRemote,
		From: myHi.CurrentRemote,
		Data: out,
	}
	r.InjectUDPPacket(myControl, theirControl, pkt)
	r.Log("Injected bogus close tunnel. Let's see!")

	// Negative path: both hostmaps must stay populated. The regular FlushAll
	// traffic keeps the tunnel active while we watch for 4 seconds.
	waitStart = time.Now()
	for {
		myIndexes := len(myControl.GetHostmap().Indexes)
		theirIndexes := len(theirControl.GetHostmap().Indexes)
		if myIndexes == 0 {
			t.Fatal("myIndexes should not be 0")
		}
		if theirIndexes == 0 {
			t.Fatal("theirIndexes should not be 0, they should have rejected this bogus packet")
		}

		since := time.Since(waitStart)
		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
		if since > time.Second*4 {
			t.Log("The tunnel would have been gone by now")
			break
		}

		time.Sleep(1 * time.Second)
		r.FlushAll()
	}

	myControl.Stop()
	theirControl.Stop()
}

View File

@@ -204,6 +204,12 @@ punchy:
# Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access.
#trusted_cas:
#- "ssh public key string"
# sandbox_dir restricts file paths for profiling commands (start-cpu-profile, save-heap-profile,
# save-mutex-profile) to the specified directory. Relative paths will be resolved within this directory,
# and absolute paths outside of it will be rejected. Default is $TMP/nebula-debug.
# The directory is NOT automatically created.
# Overriding this to "" is the same as "/" and will allow overwriting any path on the host.
#sandbox_dir: /var/tmp/nebula-debug
# EXPERIMENTAL: relay support for networks that can't establish direct connections.
relay:
@@ -286,24 +292,21 @@ tun:
# Configure logging level
logging:
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
# scenarios. Debug logging is also CPU intensive and will decrease performance overall.
# Only enable debug logging while actively investigating an issue.
# trace, debug, info, warn, or error. Default is info and is reloadable.
# fatal and panic are accepted for backwards compatibility and map to error.
#NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
# scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
# Only enable debug or trace logging while actively investigating an issue.
level: info
# json or text formats currently available. Default is text
# json or text formats currently available. Default is text.
format: text
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
# Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
#disable_timestamp: true
# timestamp format is specified in Go time format, see:
# https://golang.org/pkg/time/#pkg-constants
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
# default when `format: text`:
# when TTY attached: seconds since beginning of execution
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
# As an example, to log as RFC3339 with millisecond precision, set to:
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
# Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
# The stats section is reloadable. A HUP may change the backend, toggle stats
# on or off, switch the listen/host address, or pick up new DNS for the
# configured graphite host.
#stats:
#type: graphite
#prefix: nebula
@@ -321,10 +324,12 @@ logging:
# enables counter metrics for meta packets
# e.g.: `messages.tx.handshake`
# NOTE: `message.{tx,rx}.recv_error` is always emitted
# Not reloadable.
#message_metrics: false
# enables detailed counter metrics for lighthouse packets
# e.g.: `lighthouse.rx.HostQuery`
# Not reloadable.
#lighthouse_metrics: false
# Handshake Manager Settings
@@ -382,8 +387,8 @@ firewall:
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
# proto: `any`, `tcp`, `udp`, or `icmp`
# a port specification is ignored if proto is `icmp`
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass

View File

@@ -7,9 +7,9 @@ import (
"net"
"os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/service"
)
@@ -64,8 +64,7 @@ pki:
return err
}
logger := logrus.New()
logger.Out = os.Stdout
logger := logging.NewLogger(os.Stdout)
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {

View File

@@ -1,11 +1,13 @@
package nebula
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
"log/slog"
"net/netip"
"reflect"
"slices"
@@ -16,7 +18,6 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -67,7 +68,7 @@ type Firewall struct {
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
l *logrus.Logger
l *slog.Logger
}
type firewallMetrics struct {
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration
var tmin, tmax time.Duration
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop":
fw.InSendReject = false
default:
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`")
l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
fw.InSendReject = false
}
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop":
fw.OutSendReject = false
default:
l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`")
l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
fw.OutSendReject = false
}
@@ -249,20 +250,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
)
f.rules += ruleString + "\n"
direction := "incoming"
if !incoming {
direction = "outgoing"
}
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
ft *FirewallTable
fp firewallPort
@@ -280,6 +267,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
case firewall.ProtoUDP:
fp = ft.UDP
case firewall.ProtoICMP, firewall.ProtoICMPv6:
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
if startPort != firewall.PortAny {
f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
}
startPort = firewall.PortAny
endPort = firewall.PortAny
fp = ft.ICMP
case firewall.ProtoAny:
fp = ft.AnyProto
@@ -287,6 +280,21 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
)
f.rules += ruleString + "\n"
direction := "incoming"
if !incoming {
direction = "outgoing"
}
f.l.Info("Firewall rule added",
"firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
)
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
}
@@ -308,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
@@ -349,24 +357,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
sPort = r.Port
}
startPort, endPort, err := parsePort(sPort)
if err != nil {
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
}
var proto uint8
var startPort, endPort int32
switch r.Proto {
case "any":
proto = firewall.ProtoAny
startPort, endPort, err = parsePort(sPort)
case "tcp":
proto = firewall.ProtoTCP
startPort, endPort, err = parsePort(sPort)
case "udp":
proto = firewall.ProtoUDP
startPort, endPort, err = parsePort(sPort)
case "icmp":
proto = firewall.ProtoICMP
startPort = firewall.PortAny
endPort = firewall.PortAny
if sPort != "" {
l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
}
default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
if err != nil {
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
}
if r.Cidr != "" && r.Cidr != "any" {
_, err = netip.ParsePrefix(r.Cidr)
@@ -383,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
}
if warning := r.sanity(); warning != nil {
l.Warnf("%s rule #%v; %s", table, i, warning)
l.Warn("firewall rule sanity check",
"table", table,
"rule", i,
"warning", warning,
)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
@@ -467,7 +486,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics {
}
}
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
// Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new
// firewall object is created
func (f *Firewall) Destroy() {
//TODO: clean references if/when needed
@@ -515,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
"fwPacket", fp,
"incoming", c.incoming,
"rulesVersion", f.rulesVersion,
"oldRulesVersion", c.rulesVersion,
)
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
}
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
"fwPacket", fp,
"incoming", c.incoming,
"rulesVersion", f.rulesVersion,
"oldRulesVersion", c.rulesVersion,
)
}
c.rulesVersion = f.rulesVersion
@@ -660,6 +679,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return false
}
// this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match
if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 {
// port numbers are re-used for connection tracking of ICMP,
// but we don't want to actually filter on them.
return fp[firewall.PortAny].match(p, c, caPool)
}
var port int32
if p.Fragment {
@@ -804,11 +830,9 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
return true
}
for _, group := range groups {
if group == "any" {
if slices.Contains(groups, "any") {
return true
}
}
if host == "any" {
return true
@@ -917,7 +941,7 @@ type rule struct {
CASha string
}
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
r := rule{}
m, ok := p.(map[string]any)
@@ -948,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
}
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
l.Warn("group was an array with a single value, converting to simple value",
"table", table,
"rule", i,
)
m["group"] = v[0]
}
@@ -1018,54 +1045,56 @@ func (r *rule) sanity() error {
}
}
if r.Code != "" {
return fmt.Errorf("code specified as [%s]. Support for 'code' will be dropped in a future release, as it has never been functional", r.Code)
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) {
func parsePort(s string) (int32, int32, error) {
var err error
const notAPort int32 = -2
if s == "any" {
startPort = firewall.PortAny
endPort = firewall.PortAny
return firewall.PortAny, firewall.PortAny, nil
}
if s == "fragment" {
return firewall.PortFragment, firewall.PortFragment, nil
}
if !strings.Contains(s, `-`) {
rPort, err := strconv.Atoi(s)
if err != nil {
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
}
return int32(rPort), int32(rPort), nil
}
} else if s == "fragment" {
startPort = firewall.PortFragment
endPort = firewall.PortFragment
} else if strings.Contains(s, `-`) {
sPorts := strings.SplitN(s, `-`, 2)
sPorts[0] = strings.Trim(sPorts[0], " ")
sPorts[1] = strings.Trim(sPorts[1], " ")
for i := range sPorts {
sPorts[i] = strings.Trim(sPorts[i], " ")
}
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
}
rStartPort, err := strconv.Atoi(sPorts[0])
if err != nil {
return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
}
rEndPort, err := strconv.Atoi(sPorts[1])
if err != nil {
return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
}
startPort = int32(rStartPort)
endPort = int32(rEndPort)
startPort := int32(rStartPort)
endPort := int32(rEndPort)
if startPort == firewall.PortAny {
endPort = firewall.PortAny
}
} else {
rPort, err := strconv.Atoi(s)
if err != nil {
return 0, 0, fmt.Errorf("was not a number; `%s`", s)
}
startPort = int32(rPort)
endPort = startPort
}
return
return startPort, endPort, nil
}

View File

@@ -1,10 +1,10 @@
package firewall
import (
"context"
"log/slog"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// ConntrackCache is used as a local routine cache to know if a given flow
@@ -15,41 +15,49 @@ type ConntrackCacheTicker struct {
cacheV uint64
cacheTick atomic.Uint64
l *slog.Logger
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
l: l,
cache: ConntrackCache{},
}
go c.tick(d)
go c.tick(ctx, d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
t := time.NewTicker(d)
defer t.Stop()
for {
time.Sleep(d)
select {
case <-ctx.Done():
return
case <-t.C:
c.cacheTick.Add(1)
}
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
func (c *ConntrackCacheTicker) Get() ConntrackCache {
if c == nil {
return nil
}
if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
if c.l.Enabled(context.Background(), slog.LevelDebug) {
c.l.Debug("resetting conntrack cache", "len", ll)
}
c.cache = make(ConntrackCache, ll)
}

69
firewall/cache_test.go Normal file
View File

@@ -0,0 +1,69 @@
package firewall
import (
"bytes"
"log/slog"
"strings"
"testing"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
// The tests below pin the log format produced by ConntrackCacheTicker.Get
// so changes cannot silently break what operators are grepping for. The
// ticker's internal state (cache + cacheTick) is poked directly to avoid
// racing a goroutine-driven tick in tests.
// newFixedTicker builds a ConntrackCacheTicker pre-populated with cacheLen
// distinct conntrack entries (one per LocalPort 1..cacheLen) and with its
// tick counter advanced so the next Get() takes the reset path. State is set
// directly rather than via NewConntrackCacheTicker to avoid racing the
// background tick goroutine in tests.
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
	t.Helper()

	cache := make(ConntrackCache, cacheLen)
	for port := 1; port <= cacheLen; port++ {
		cache[Packet{LocalPort: uint16(port)}] = struct{}{}
	}

	c := &ConntrackCacheTicker{
		l:     l,
		cache: cache,
	}
	// cacheV starts at 0, so a stored tick of 1 forces Get() to reset.
	c.cacheTick.Store(1)
	return c
}
// TestConntrackCacheTicker_Get_TextFormat pins the exact text-handler line
// emitted when Get() resets a non-empty cache at debug level.
// NOTE(review): the exact match assumes test.NewLoggerWithOutputAndLevel
// configures the handler to omit the time attribute — confirm against the
// helper, otherwise this equality can never hold.
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
	buf := &bytes.Buffer{}
	l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
	c := newFixedTicker(t, l, 3)

	c.Get()

	// len=3 must match the number of entries seeded by newFixedTicker.
	assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String())
}
// TestConntrackCacheTicker_Get_JSONFormat pins the JSON-handler payload for
// the same reset message, so the structured fields (level, msg, len) cannot
// silently change shape.
// NOTE(review): assert.JSONEq requires the whole object to match, so this
// assumes test.NewJSONLoggerWithOutput suppresses the time attribute —
// confirm against the helper.
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
	buf := &bytes.Buffer{}
	l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug)
	c := newFixedTicker(t, l, 2)

	c.Get()

	assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
}
// TestConntrackCacheTicker_Get_QuietBelowDebug verifies that the reset path
// emits nothing when the logger level is above debug, even though the cache
// is non-empty and the reset itself still happens.
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
	buf := &bytes.Buffer{}
	l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
	c := newFixedTicker(t, l, 5)

	c.Get()

	assert.Empty(t, buf.String())
}
// TestConntrackCacheTicker_Get_QuietWhenCacheEmpty verifies that advancing
// the tick with an empty cache produces no log output: Get() only logs (and
// reallocates) inside its len(cache) > 0 branch.
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
	buf := &bytes.Buffer{}
	l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
	c := newFixedTicker(t, l, 0)

	c.Get()

	assert.Empty(t, buf.String())
}

View File

@@ -22,7 +22,10 @@ const (
type Packet struct {
LocalAddr netip.Addr
RemoteAddr netip.Addr
// LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP.
LocalPort uint16
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
RemotePort uint16
Protocol uint8
Fragment bool
@@ -46,6 +49,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = "tcp"
case ProtoICMP:
proto = "icmp"
case ProtoICMPv6:
proto = "icmpv6"
case ProtoUDP:
proto = "udp"
default:

View File

@@ -3,13 +3,13 @@ package nebula
import (
"bytes"
"errors"
"log/slog"
"math"
"net/netip"
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
}
func TestFirewall_AddRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@@ -87,9 +86,10 @@ func TestFirewall_AddRule(t *testing.T) {
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
//no matter what port is given for icmp, it should end up as "any"
assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
@@ -176,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
}
func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
@@ -253,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
}
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -484,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -543,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
}
func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -632,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
}
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -670,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -734,10 +728,152 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
func TestFirewall_ICMPPortBehavior(t *testing.T) {
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
network := netip.MustParsePrefix("1.2.3.4/24")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
groups: []string{"default-group"},
issuer: "signer-shasum",
},
InvertedGroups: map[string]struct{}{"default-group": {}},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
cp := cert.NewCAPool()
templ := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
Protocol: firewall.ProtoICMP,
Fragment: false,
}
t.Run("ICMP allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
})
t.Run("Any proto, some ports allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
//different ID is blocked
p.RemotePort++
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
ob := &bytes.Buffer{}
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
@@ -900,53 +1036,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err)
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
@@ -955,28 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
func TestAddFirewallRulesFromConfig(t *testing.T) {
l := test.NewLogger()
// Test adding tcp rule
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule no port
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -984,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -999,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
@@ -1082,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
}
func TestFirewall_convertRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
// Ensure group array of 1 is converted and a warning is printed
c := map[string]any{
@@ -1092,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
}
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
assert.Contains(t, ob.String(), "table=test")
assert.Contains(t, ob.String(), "rule=1")
require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups)
@@ -1118,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
@@ -1234,7 +1377,7 @@ type testsetup struct {
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
@@ -1245,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
@@ -1262,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out

24
go.mod
View File

@@ -1,9 +1,10 @@
module github.com/slackhq/nebula
go 1.25
go 1.25.0
require (
dario.cat/mergo v1.0.2
filippo.io/bigmod v0.1.0
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
@@ -12,26 +13,25 @@ require (
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
github.com/miekg/dns v1.1.70
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/miekg/dns v1.1.72
github.com/miekg/pkcs11 v1.1.2
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.23.2
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.4
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.47.0
golang.org/x/crypto v0.50.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/sys v0.40.0
golang.org/x/term v0.39.0
golang.org/x/net v0.52.0
golang.org/x/sync v0.20.0
golang.org/x/sys v0.43.0
golang.org/x/term v0.42.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
golang.zx2c4.com/wireguard/windows v0.6.1
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
@@ -49,7 +49,7 @@ require (
github.com/prometheus/procfs v0.16.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.org/x/tools v0.43.0 // indirect
)

44
go.sum
View File

@@ -1,6 +1,8 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8=
filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -83,10 +85,10 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ=
github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -131,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
@@ -162,16 +162,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -223,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -233,8 +233,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
golang.zx2c4.com/wireguard/windows v0.6.1 h1:XMaKojH1Hs/raMrmnir4n35nTvzvWj7NmSYzHn2F4qU=
golang.zx2c4.com/wireguard/windows v0.6.1/go.mod h1:04aqInu5GYuTFvMuDw/rKBAF7mHrltW/3rekpfbbZDM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=

View File

@@ -2,11 +2,12 @@ package nebula
import (
"bytes"
"context"
"log/slog"
"net/netip"
"time"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
@@ -18,8 +19,11 @@ import (
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
f.l.Error("Failed to generate index",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
@@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
f.l.Error("Unable to handshake with host because no certificate is available",
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
f.l.Error("Failed to create connection state",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
hh.hostinfo.ConnectionState = ci
@@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
f.l.Error("Failed to marshal handshake message",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"certVersion", v,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
@@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
f.l.Error("Failed to call noise.WriteMessage",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
@@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
f.l.Error("Unable to handshake with host because no certificate is available",
"from", via,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", cs.initiatingVersion,
)
return
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
f.l.Error("Failed to create connection state",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
f.l.Error("Failed to call noise.ReadMessage",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
f.l.Error("Failed unmarshal handshake message",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
f.l.Info("Handshake did not contain a certificate",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp)
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
attrs := []slog.Attr{
slog.Any("error", err),
slog.Any("from", via),
slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
slog.Any("certVpnNetworks", rc.Networks()),
slog.String("certFingerprint", fp),
}
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
}
e.Info("Invalid certificate from host")
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
f.l.Info("public key mismatch between certificate and handshake",
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"cert", remoteCert,
)
return
}
@@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{
"from": via,
"handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Might be unable to handshake with host due to missing certificate version",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"cert", remoteCert,
)
}
} else {
// Record the certificate we are actually using
@@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate")
f.l.Info("No networks in certificate",
"error", err,
"from", via,
"cert", remoteCert,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
f.l.Error("Refusing to handshake with myself",
"vpnNetworks", vpnNetworks,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
vpnAddrs[i] = network.Addr()
@@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
if !via.IsRelayed {
// We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
Debug("lighthouse.remote_allow_list denied incoming handshake")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", vpnAddrs,
"from", via,
)
}
return
}
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
f.l.Error("Failed to generate index",
"error", err,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
},
}
msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs,
"from": via,
"certName": certName,
"certVersion": certVersion,
"fingerprint": fingerprint,
"issuer": issuer,
"initiatorIndex": hs.Details.InitiatorIndex,
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
msgRxL := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
@@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
"myCertVersion", ci.myCert.Version(),
)
return
}
@@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
f.l.Error("Failed to marshal handshake message",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
f.l.Error("Failed to call noise.WriteMessage",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
f.l.Error("Noise did not arrive at a key",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
if !via.IsRelayed {
err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
f.l.Error("Failed to send handshake message",
"vpnAddrs", existing.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
"error", err,
)
} else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
f.l.Info("Handshake message sent",
"vpnAddrs", existing.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
}
return
} else {
@@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
f.l.Info("Handshake message sent",
"vpnAddrs", existing.vpnAddrs,
"relay", via.relayHI.vpnAddrs[0],
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
return
}
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake too old")
f.l.Info("Handshake too old",
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"oldHandshakeTime", existing.lastHandshakeTime,
"newHandshakeTime", hostinfo.lastHandshakeTime,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
Error("Failed to add HostInfo due to localIndex collision")
f.l.Error("Failed to add HostInfo due to localIndex collision",
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"localIndex", hostinfo.localIndexId,
"collision", existing.vpnAddrs,
)
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to add HostInfo to HostMap")
f.l.Error("Failed to add HostInfo to HostMap",
"error", err,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
}
@@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed {
err = f.outside.WriteTo(msg, via.UdpAddr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
log := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
if err != nil {
log.WithError(err).Error("Failed to send handshake")
log.Error("Failed to send handshake", "error", err)
} else {
log.Info("Handshake message sent")
}
@@ -448,20 +533,29 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
// it's correctly marked as working.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
f.l.Info("Handshake message sent",
"vpnAddrs", vpnAddrs,
"relay", via.relayHI.vpnAddrs[0],
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
}
f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
// Don't wait for UpdateWorker
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
f.lightHouse.TriggerUpdate()
}
return
}
@@ -478,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
)
}
return false
}
}
@@ -486,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
f.l.Error("Failed to call noise.ReadMessage",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"header", h,
)
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
// to DOS us. Every other error condition after should allow a possible good handshake to complete in the
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
f.l.Error("Noise did not arrive at a key",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
// the handshake state machine. Tear it down
@@ -507,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
f.l.Error("Failed unmarshal handshake message",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
@@ -516,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
f.l.Info("Handshake did not contain a certificate",
"error", err,
"from", via,
"vpnAddrs", hostinfo.vpnAddrs,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true
}
@@ -530,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp).
WithField("certVpnNetworks", rc.Networks())
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
attrs := []slog.Attr{
slog.Any("error", err),
slog.Any("from", via),
slog.Any("vpnAddrs", hostinfo.vpnAddrs),
slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
slog.String("certFingerprint", fp),
slog.Any("certVpnNetworks", rc.Networks()),
}
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
}
e.Info("Invalid certificate from host")
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return true
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
f.l.Info("public key mismatch between certificate and handshake",
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cert", remoteCert,
)
return true
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("No networks in certificate")
f.l.Info("No networks in certificate",
"error", err,
"from", via,
"vpnAddrs", hostinfo.vpnAddrs,
"cert", remoteCert,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true
}
@@ -596,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
// Ensure the right host responded
if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
f.l.Info("Incorrect host responded to handshake",
"intendedVpnAddrs", hostinfo.vpnAddrs,
"haveVpnNetworks", vpnNetworks,
"from", via,
"certName", certName,
"certVersion", certVersion,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// Release our old handshake from pending, it should not continue
f.handshakeManager.DeleteHostInfo(hostinfo)
@@ -613,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
f.l.Info("Blocked addresses for handshakes",
"blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
"vpnNetworks", vpnNetworks,
"remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
)
// Swap the packet store to benefit the original intended recipient
newHH.packetStore = hh.packetStore
@@ -634,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore))
msgRxL := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"durationNs", duration,
"sentCachedPackets", len(hh.packetStore),
)
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
@@ -658,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Sending stored packets",
"count", len(hh.packetStore),
)
}
if len(hh.packetStore) > 0 {
@@ -674,5 +804,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration)
// Don't wait for UpdateWorker
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
f.lightHouse.TriggerUpdate()
}
return false
}

View File

@@ -6,13 +6,13 @@ import (
"crypto/rand"
"encoding/binary"
"errors"
"log/slog"
"net/netip"
"slices"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
@@ -59,7 +59,7 @@ type HandshakeManager struct {
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
f *Interface
l *logrus.Logger
l *slog.Logger
// can be used to trigger outbound handshake for the given vpnIp
trigger chan netip.Addr
@@ -78,32 +78,32 @@ type HandshakeHostInfo struct {
hostinfo *HostInfo
}
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
if len(hh.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
copy(tempPacket, packet)
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
if l.Level >= logrus.DebugLevel {
hh.hostinfo.logger(l).
WithField("length", len(hh.packetStore)).
WithField("stored", true).
Debugf("Packet store")
if l.Enabled(context.Background(), slog.LevelDebug) {
hh.hostinfo.logger(l).Debug("Packet store",
"length", len(hh.packetStore),
"stored", true,
)
}
} else {
m.dropped.Inc(1)
if l.Level >= logrus.DebugLevel {
hh.hostinfo.logger(l).
WithField("length", len(hh.packetStore)).
WithField("stored", false).
Debugf("Packet store")
if l.Enabled(context.Background(), slog.LevelDebug) {
hh.hostinfo.logger(l).Debug("Packet store",
"length", len(hh.packetStore),
"stored", false,
)
}
}
}
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
indexes: map[uint32]*HandshakeHostInfo{},
@@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
// First remote allow list check before we know the vpnIp
if !via.IsRelayed {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via)
return
}
}
@@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo := hh.hostinfo
// If we are out of time, clean up
if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
WithField("initiatorIndex", hh.hostinfo.localIndexId).
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
Info("Handshake timed out")
hh.hostinfo.logger(hm.l).Info("Handshake timed out",
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
"initiatorIndex", hh.hostinfo.localIndexId,
"remoteIndex", hh.hostinfo.remoteIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"durationNs", time.Since(hh.startTime).Nanoseconds(),
)
hm.metricTimedOut.Inc(1)
hm.DeleteHostInfo(hostinfo)
return
@@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(hm.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
hostinfo.logger(hm.l).Error("Failed to send handshake message",
"udpAddr", addr,
"initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"error", err,
)
} else {
sentTo = append(sentTo, addr)
@@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
// so only log when the list of remotes has changed
if remotesHaveChanged {
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
} else if hm.l.Level >= logrus.DebugLevel {
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Debug("Handshake message sent")
hostinfo.logger(hm.l).Info("Handshake message sent",
"udpAddrs", sentTo,
"initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
} else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(hm.l).Debug("Handshake message sent",
"udpAddrs", sentTo,
"initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
}
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to
@@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
hm.f.Handshake(relay)
continue
}
@@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
}
m := NebulaControl{
@@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": relay}).
Info("send CreateRelayRequest")
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", idx,
"relay", relay,
)
}
}
continue
@@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
switch existingRelay.State {
case Established:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
@@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}).
Info("send CreateRelayRequest")
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", existingRelay.LocalIndex,
"relay", relay,
)
}
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State).
WithField("relay", relay).
Errorf("Relay unexpected state")
hostinfo.logger(hm.l).Error("Relay unexpected state",
"vpnIp", vpnIp,
"state", existingRelay.State,
"relay", relay,
)
}
}
@@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
@@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
@@ -590,7 +593,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
hm.Lock()
defer hm.Unlock()
for i := 0; i < 32; i++ {
for range 32 {
index, err := generateIndex(hm.l)
if err != nil {
return err
@@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
hm.indexes = map[uint32]*HandshakeHostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Pending hostmap hostInfo deleted",
"hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
)
}
}
@@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() {
// Utility functions below
func generateIndex(l *logrus.Logger) (uint32, error) {
func generateIndex(l *slog.Logger) (uint32, error) {
b := make([]byte, 4)
// Let zero mean we don't know the ID, so don't generate zero
@@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
for index == 0 {
_, err := rand.Read(b)
if err != nil {
l.Errorln(err)
l.Error("Failed to generate index", "error", err)
return 0, err
}
index = binary.BigEndian.Uint32(b)
}
if l.Level >= logrus.DebugLevel {
l.WithField("index", index).
Debug("Generated index")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Generated index", "index", index)
}
return index, nil
}

View File

@@ -1,9 +1,11 @@
package nebula
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
@@ -13,10 +15,10 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
)
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
@@ -60,7 +62,7 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix]
l *logrus.Logger
l *slog.Logger
}
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
@@ -313,7 +315,7 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
hm := newHostMap(l)
hm.reload(c, true)
@@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm.reload(c, false)
})
l.WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")
l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
return hm
}
func newHostMap(l *logrus.Logger) *HostMap {
func newHostMap(l *slog.Logger) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
@@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
hm.l.Warn("Failed to parse preferred ranges, ignoring",
"error", err,
"range", rawPreferredRanges,
)
continue
}
@@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
hm.l.Info("preferred_ranges changed",
"oldPreferredRanges", *oldRanges,
"newPreferredRanges", preferredRanges,
)
}
}
}
@@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
hm.Indexes = map[uint32]*HostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Hostmap hostInfo deleted",
"hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
)
}
if isLastHostinfo {
@@ -604,9 +612,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns {
if f.dnsServer != nil {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
}
for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
@@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
Debug("Hostmap vpnIp added")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Hostmap vpnIp added",
"hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
)
}
}
@@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
}
}
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
if i == nil {
return logrus.NewEntry(l)
return l
}
li := l.WithField("vpnAddrs", i.vpnAddrs).
WithField("localIndex", i.localIndexId).
WithField("remoteIndex", i.remoteIndexId)
li := l.With(
"vpnAddrs", i.vpnAddrs,
"localIndex", i.localIndexId,
"remoteIndex", i.remoteIndexId,
)
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Certificate.Name())
li = li.With("certName", peerCert.Certificate.Name())
}
}
@@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage
var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
if l.Level >= logrus.TraceLevel {
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
if l.Enabled(context.Background(), logging.LevelTrace) {
l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
"interfaceName", i.Name,
"allow", allow,
)
}
if !allow {
@@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
}
if !addr.IsValid() {
if l.Level >= logrus.DebugLevel {
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("addr was invalid", "localAddr", rawAddr)
}
continue
}
@@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel {
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
if l.Enabled(context.Background(), logging.LevelTrace) {
l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
"localAddr", addr,
"allowed", isAllowed,
)
}
if !isAllowed {
continue

View File

@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
c := config.NewC(test.NewLogger())
hm := NewHostMapFromConfig(l, c)

View File

@@ -1,5 +1,4 @@
//go:build e2e_testing
// +build e2e_testing
package nebula

119
inside.go
View File

@@ -1,9 +1,10 @@
package nebula
import (
"context"
"log/slog"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
@@ -14,8 +15,11 @@ import (
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while validating outbound packet",
"packet", packet,
"error", err,
)
}
return
}
@@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
f.l.Error("Failed to forward to tun", "error", err)
}
}
// Otherwise, drop. On linux, we should never see these packets - Linux
@@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
"vpnAddr", fwPacket.RemoteAddr,
"fwPacket", fwPacket,
)
}
return
}
@@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping outbound packet",
"fwPacket", fwPacket,
"reason", dropReason,
)
}
}
}
@@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
_, err := f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
f.l.Error("Failed to write to tun", "error", err)
}
}
@@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
}
if len(out) > iputil.MaxRejectPacketSize {
if f.l.GetLevel() >= logrus.InfoLevel {
f.l.
WithField("packet", packet).
WithField("outPacket", out).
Info("rejectOutside: packet too big, not sending")
if f.l.Enabled(context.Background(), slog.LevelInfo) {
f.l.Info("rejectOutside: packet too big, not sending",
"packet", packet,
"outPacket", out,
)
}
return
}
@@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
// This would also need to interact with unsafe_route updates through reloading the config or
// use of the use_system_route_table option
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("destination", destinationAddr).
WithField("originalGateway", gatewayAddr).
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
"destination", destinationAddr,
"originalGateway", gatewayAddr,
)
}
for i := range gateways {
@@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
fp := &firewall.Packet{}
err := newPacket(p, false, fp)
if err != nil {
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
return
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping cached packet",
"fwPacket", fp,
"reason", dropReason,
)
}
return
}
@@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
})
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", vpnAddr).
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
"vpnAddr", vpnAddr,
)
}
return
}
@@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo,
if noiseutil.EncryptLockNeeded {
via.ConnectionState.writeLock.Unlock()
}
via.logger(f.l).
WithField("outCap", cap(out)).
WithField("payloadLen", len(ad)).
WithField("headerLen", len(out)).
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()).
Error("SendVia out buffer not large enough for relay")
via.logger(f.l).Error("SendVia out buffer not large enough for relay",
"outCap", cap(out),
"payloadLen", len(ad),
"headerLen", len(out),
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
)
return
}
@@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo,
via.ConnectionState.writeLock.Unlock()
}
if err != nil {
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
return
}
err = f.writers[0].WriteTo(out, via.remote)
if err != nil {
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
}
f.connectionManager.RelayUsed(relay.LocalIndex)
}
@@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
"vpnAddrs", hostinfo.vpnAddrs,
)
}
}
@@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
"error", err,
"udpAddr", remote,
"counter", c,
"attemptedCounter", c,
)
return
}
if remote.IsValid() {
err = f.writers[q].WriteTo(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
"error", err,
"udpAddr", remote,
)
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
"error", err,
"udpAddr", remote,
)
}
} else {
// Try to send via a relay
@@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
"relay", relayIP,
"error", err,
)
continue
}
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)

View File

@@ -1,5 +1,4 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd
package nebula

View File

@@ -1,5 +1,4 @@
//go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd
// +build !darwin,!dragonfly,!freebsd,!netbsd,!openbsd
package nebula

View File

@@ -6,15 +6,15 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -31,7 +31,7 @@ type InterfaceConfig struct {
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
DnsServer *dnsServer
HandshakeManager *HandshakeManager
lightHouse *LightHouse
connectionManager *connectionManager
@@ -48,7 +48,7 @@ type InterfaceConfig struct {
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
l *logrus.Logger
l *slog.Logger
}
type Interface struct {
@@ -59,7 +59,7 @@ type Interface struct {
firewall *Firewall
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
dnsServer *dnsServer
createTime time.Time
lightHouse *LightHouse
myBroadcastAddrsTable *bart.Lite
@@ -87,14 +87,22 @@ type Interface struct {
conntrackCacheTimeout time.Duration
ctx context.Context
writers []udp.Conn
readers []io.ReadWriteCloser
wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown.
// nil means "no fatal error" (yet)
fatalErr atomic.Pointer[error]
// triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr
triggerShutdown func()
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger
l *slog.Logger
}
type EncWriter interface {
@@ -165,12 +173,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
cs := c.pki.getCertState()
ifce := &Interface{
ctx: ctx,
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
firewall: c.Firewall,
serveDns: c.ServeDns,
dnsServer: c.DnsServer,
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
@@ -211,19 +220,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
func (f *Interface) activate() error {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
if err != nil {
f.l.WithError(err).Error("Failed to get udp listen address")
f.l.Error("Failed to get udp listen address", "error", err)
}
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()).
WithField("fips140", fips140.Enabled()).
Info("Nebula interface is active")
f.l.Info("Nebula interface is active",
"interface", f.inside.Name(),
"networks", f.myVpnNetworks,
"build", f.version,
"udpAddr", addr,
"boringcrypto", boringEnabled(),
"fips140", fips140.Enabled(),
)
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
@@ -240,33 +252,58 @@ func (f *Interface) activate() {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
f.l.Fatal(err)
return err
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil {
f.wg.Done()
f.inside.Close()
f.l.Fatal(err)
}
return err
}
func (f *Interface) run() {
return nil
}
func (f *Interface) run() (func() error, error) {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
go f.listenOut(i)
f.wg.Go(func() {
f.listenOut(i)
})
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
f.wg.Go(func() {
f.listenIn(f.readers[i], i)
})
}
return func() error {
f.wg.Wait()
if e := f.fatalErr.Load(); e != nil {
return *e
}
return nil
}, nil
}
// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one
func (f *Interface) onFatal(err error) {
swapped := f.fatalErr.CompareAndSwap(nil, &err)
if !swapped {
return
}
if f.triggerShutdown != nil {
f.triggerShutdown()
}
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -274,42 +311,47 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
})
if err != nil && !f.closed.Load() {
f.l.Error("Error while reading inbound packet, closing", "error", err)
f.onFatal(err)
}
f.l.Debug("underlay reader is done", "reader", i)
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
if !f.closed.Load() {
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
f.onFatal(err)
}
break
}
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
f.l.Debug("overlay reader is done", "reader", i)
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -329,7 +371,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
if initial || c.HasChanged("pki.disconnect_invalid") {
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
if !initial {
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
}
}
}
@@ -343,7 +385,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
f.l.Error("Error while creating firewall during reload", "error", err)
return
}
@@ -356,10 +398,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
} else {
fw.Conntrack = conntrack
}
@@ -367,10 +410,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
f.firewall = fw
oldFw.Destroy()
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed")
f.l.Info("New firewall has been installed",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
}
func (f *Interface) reloadSendRecvError(c *config.C) {
@@ -392,8 +436,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
}
}
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
Info("Loaded send_recv_error config")
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
}
}
@@ -416,8 +459,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
}
}
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
Info("Loaded accept_recv_error config")
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
}
}
@@ -484,23 +526,23 @@ func (f *Interface) GetCertState() *CertState {
}
func (f *Interface) Close() error {
var errs []error
f.closed.Store(true)
for _, u := range f.writers {
// Release the udp readers
for i, u := range f.writers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing udp socket")
}
}
for i, r := range f.readers {
if i == 0 {
continue // f.readers[0] is f.inside, which we want to save for last
}
if err := r.Close(); err != nil {
f.l.WithError(err).Error("Error while closing tun reader")
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
errs = append(errs, err)
}
}
// Release the tun device
return f.inside.Close()
// Release the tun device (closing the tun also closes all readers)
closeErr := f.inside.Close()
if closeErr != nil {
errs = append(errs, closeErr)
}
f.wg.Done()
return errors.Join(errs...)
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
@@ -15,10 +16,10 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
)
@@ -69,18 +70,19 @@ type LightHouse struct {
// Addr's of relays that can be used by peers to access me
relaysForMe atomic.Pointer[[]netip.Addr]
updateTrigger chan struct{}
queryChan chan netip.Addr
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
l *slog.Logger
}
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 {
@@ -105,6 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
updateTrigger: make(chan struct{}, 1),
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
@@ -131,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
case *util.ContextualError:
v.Log(l)
case error:
l.WithError(err).Error("failed to reload lighthouse")
l.Error("failed to reload lighthouse", "error", err)
}
})
@@ -203,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
addr := addrs[0].Unmap()
if lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
"addr", rawAddr,
"entry", i+1,
)
continue
}
@@ -222,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
if !initial {
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
lh.l.Info("lighthouse.interval changed",
"interval", lh.interval.Load(),
)
if lh.updateCancel != nil {
// May not always have a running routine
@@ -316,6 +323,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
lh.TriggerUpdate()
}
}
@@ -333,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
for _, v := range c.GetStringSlice("relay.relays", nil) {
configRIP, err := netip.ParseAddr(v)
if err != nil {
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
lh.l.Warn("Parse relay from config failed",
"relay", v,
"error", err,
)
} else {
lh.l.WithField("relay", v).Info("Read relay from config")
lh.l.Info("Read relay from config", "relay", v)
relaysForMe = append(relaysForMe, configRIP)
}
}
@@ -360,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
}
if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
"vpnAddr", addr,
"networks", lh.myVpnNetworks,
)
}
out[i] = addr
}
@@ -432,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
}
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
"vpnAddr", vpnAddr,
"networks", lh.myVpnNetworks,
"entry", i+1,
)
}
vals, ok := v.([]any)
@@ -534,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
for _, addr := range allVpnAddrs {
srm := lh.addrMap[addr]
if srm == rm {
delete(lh.addrMap, addr)
if lh.l.Level >= logrus.DebugLevel {
lh.l.Debugf("deleting %s from lighthouse.", addr)
if debugEnabled {
lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
}
}
}
@@ -656,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddrs", vpnAddrs,
"udpAddr", to,
"allow", allow,
)
}
if !allow {
return false
@@ -675,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
udpAddr := protoV4AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
}
if !allow {
@@ -695,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
udpAddr := protoV6AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
}
if !allow {
@@ -713,23 +739,16 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
// IsLighthouseAddr reports whether vpnAddr is one of our configured
// lighthouse addresses.
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
	l := lh.GetLighthouses()
	// The rendered block contained both the old manual loop and the new
	// slices.Contains form (unreachable after `return false`); only the
	// single containment check is live code.
	return slices.Contains(l, vpnAddr)
}
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
l := lh.GetLighthouses()
for i := range vpnAddrs {
for j := range l {
if l[j] == vpnAddrs[i] {
if slices.Contains(l, vpnAddrs[i]) {
return true
}
}
}
return false
}
@@ -779,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
if v == cert.Version1 {
if !addr.Is4() {
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
Error("Can't query lighthouse for v6 address using a v1 protocol")
lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
@@ -791,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v1Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v1 query payload")
lh.l.Error("Failed to marshal lighthouse v1 query payload",
"error", err,
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -808,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v2Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v2 query payload")
lh.l.Error("Failed to marshal lighthouse v2 query payload",
"error", err,
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -819,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried++
} else {
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
lh.l.Debug("unsupported protocol version",
"op", "query",
"queryVpnAddr", addr,
"version", v,
)
continue
}
}
@@ -848,11 +877,24 @@ func (lh *LightHouse) StartUpdateWorker() {
return
case <-clockSource.C:
continue
case <-lh.updateTrigger:
continue
}
}
}()
}
// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking
// operation intended to be called after a handshake completes with a lighthouse,
// so the lighthouse has our current addresses without waiting for the next
// periodic update.
func (lh *LightHouse) TriggerUpdate() {
	select {
	case lh.updateTrigger <- struct{}{}:
	default:
		// updateTrigger is buffered with capacity 1, so a pending trigger is
		// already queued; drop this request to coalesce and never block.
	}
}
func (lh *LightHouse) SendUpdate() {
var v4 []*V4AddrPort
var v6 []*V6AddrPort
@@ -898,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
if v == cert.Version1 {
if v1Update == nil {
if !lh.myVpnNetworks[0].Addr().Is4() {
lh.l.WithField("lighthouseAddr", lhVpnAddr).
Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
"lighthouseAddr", lhVpnAddr,
)
continue
}
var relays []uint32
@@ -923,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
v1Update, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
Error("Error while marshaling for lighthouse v1 update")
lh.l.Error("Error while marshaling for lighthouse v1 update",
"error", err,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -950,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
v2Update, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
Error("Error while marshaling for lighthouse v2 update")
lh.l.Error("Error while marshaling for lighthouse v2 update",
"error", err,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -960,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
updated++
} else {
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v)
lh.l.Debug("unsupported protocol version",
"op", "update",
"version", v,
)
continue
}
}
@@ -974,7 +1024,7 @@ type LightHouseHandler struct {
out []byte
pb []byte
meta *NebulaMeta
l *logrus.Logger
l *slog.Logger
}
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
@@ -1023,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
lhh.l.Error("Failed to unmarshal lighthouse packet",
"error", err,
"vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return
}
if n.Details == nil {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
lhh.l.Error("Invalid lighthouse update",
"vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return
}
@@ -1058,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("I don't answer queries, but received one", "from", addr)
}
return
}
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Dropping malformed HostQuery",
"from", fromVpnAddrs,
"details", n.Details,
)
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
"vpnAddrs", fromVpnAddrs,
"queryVpnAddr", queryVpnAddr,
)
}
return
}
@@ -1101,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
}
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
lhh.l.Error("Failed to marshal lighthouse host query reply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1129,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("unable to punch to host, no addresses in common",
"to", crt.Networks(),
)
}
}
}
@@ -1156,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
}
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
lhh.l.Error("Failed to marshal lighthouse host was queried for",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1198,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("version", v).Debug("unsupported protocol version")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("unsupported protocol version",
"op", "coalesceAnswers",
"version", v,
)
}
}
}
@@ -1212,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Error("dropping malformed HostQueryReply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
}
return
}
@@ -1238,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
}
return
}
@@ -1262,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Host sent invalid update",
"vpnAddrs", fromVpnAddrs,
"answer", detailsVpnAddr,
)
}
return
}
@@ -1285,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
switch useVersion {
case cert.Version1:
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
"vpnAddrs", fromVpnAddrs,
)
return
}
vpnAddrB := fromVpnAddrs[0].As4()
@@ -1293,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
case cert.Version2:
// do nothing, we want to send a blank message
default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
lhh.l.Error("invalid protocol version", "useVersion", useVersion)
return
}
ln, err := n.MarshalTo(lhh.pb)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
lhh.l.Error("Failed to marshal lighthouse host update ack",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1316,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("dropping invalid HostPunchNotification",
"details", n.Details,
"error", err,
)
}
return
}
@@ -1334,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}()
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Punching",
"vpnPeer", vpnPeer,
"logVpnAddr", logVpnAddr,
)
}
}
@@ -1360,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
if lhh.lh.punchy.GetRespond() {
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Sending a nebula test packet",
"vpnAddr", detailsVpnAddr,
)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine

View File

@@ -1,45 +0,0 @@
package nebula
import (
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// configLogger applies the logging.* settings from c to the logrus logger l:
// level, timestamp handling, and output format ("text" or "json"). It mutates
// l in place and returns an error for an unknown level or format, so it is
// safe to call both at startup and from a config reload callback.
func configLogger(l *logrus.Logger, c *config.C) error {
	// set up our logging level
	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
	}
	l.SetLevel(logLevel)
	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
	// fullTimestamp is enabled only when the operator supplied a custom
	// timestamp format; an empty setting falls back to RFC3339.
	timestampFormat := c.GetString("logging.timestamp_format", "")
	fullTimestamp := (timestampFormat != "")
	if timestampFormat == "" {
		timestampFormat = time.RFC3339
	}
	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
	switch logFormat {
	case "text":
		l.Formatter = &logrus.TextFormatter{
			TimestampFormat:  timestampFormat,
			FullTimestamp:    fullTimestamp,
			DisableTimestamp: disableTimestamp,
		}
	case "json":
		l.Formatter = &logrus.JSONFormatter{
			TimestampFormat:  timestampFormat,
			DisableTimestamp: disableTimestamp,
		}
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
	}
	return nil
}

233
logging/logger.go Normal file
View File

@@ -0,0 +1,233 @@
// Package logging wires the nebula runtime-reconfigurable slog handler used
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
// NewLogger, then call ApplyConfig at startup and from a config reload
// callback to push logging.level, logging.format, and
// logging.disable_timestamp changes onto the logger without rebuilding it.
package logging
import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"sync/atomic"
"time"
)
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
// here keeps the logging package from depending on config directly, which
// would cycle through the shared test helpers (test.NewLogger imports
// logging, and config's tests import test). *config.C satisfies this
// interface structurally with no adapter.
type Config interface {
	// GetString returns the string configured at key, falling back to def.
	GetString(key, def string) string
	// GetBool returns the bool configured at key, falling back to def.
	GetBool(key string, def bool) bool
}

// LevelTrace is a custom slog level below Debug, used when logging.level is
// "trace". slog has no builtin trace level; the value is one step below
// slog.LevelDebug in slog's 4-point spacing.
const LevelTrace = slog.Level(-8)
// NewLogger returns a *slog.Logger backed by NewHandler, so its level,
// output format, and timestamp emission can be reconfigured at runtime via
// ApplyConfig and the SSH debug commands. Until ApplyConfig runs, the
// logger emits info-level text output, with timestamps in slog's default
// RFC3339Nano form; set logging.disable_timestamp in config to drop them.
//
// ApplyConfig and the SSH commands find the reconfiguration surface by
// structurally type-asserting l.Handler() for the subset of
// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)}
// a handler implements, so replacement handlers (tests, platform sinks)
// opt in method by method; a plain *slog.Logger is a silent no-op.
func NewLogger(w io.Writer) *slog.Logger {
	h := NewHandler(w)
	return slog.New(h)
}
// NewHandler builds the reconfigurable *Handler that NewLogger wraps. It is
// exported so platform-specific sinks (notably
// cmd/nebula-service/logs_windows.go) can wrap it with extra behavior, such
// as tagging each record with its Event Log severity, while reusing the
// level / format / timestamp / WithAttrs machinery implemented here.
func NewHandler(w io.Writer) *Handler {
	shared := &handlerRoot{}
	shared.level.Set(slog.LevelInfo)
	opts := slog.HandlerOptions{Level: &shared.level}
	return &Handler{
		root: shared,
		text: slog.NewTextHandler(w, &opts),
		json: slog.NewJSONHandler(w, &opts),
	}
}
// handlerRoot carries the reconfiguration state shared by every logger
// derived from a NewHandler call. All fields are consulted on the log
// path and updated lock-free.
type handlerRoot struct {
	level            slog.LevelVar // minimum emitted level; LevelVar handles atomicity
	disableTimestamp atomic.Bool   // when true, Handle zeroes r.Time before dispatch
	// jsonMode picks which of the pre-derived inner handlers Handler.Handle
	// dispatches to. Flipping it propagates instantly to every derived logger
	// without rebuilding or chain-replaying anything.
	jsonMode atomic.Bool
}
// Handler is the slog.Handler returned by NewHandler. It holds two
// pre-derived slog handlers -- one text, one json -- both built from the
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
// effect immediately across the whole process without having to rebuild
// any derived loggers.
type Handler struct {
	root *handlerRoot // shared reconfig state; identical pointer in every derived Handler
	text slog.Handler // text-format sink carrying this Handler's attr/group chain
	json slog.Handler // json-format sink carrying the same attr/group chain
}
// Enabled reports whether records at level l pass the root's current
// minimum level.
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
	return l >= h.root.level.Level()
}
// Handle emits r on whichever pre-derived sink matches the currently
// selected format, first clearing the record time when timestamps are
// disabled (slog's builtin handlers skip a zero time).
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
	if h.root.disableTimestamp.Load() {
		r.Time = time.Time{}
	}
	sink := h.text
	if h.root.jsonMode.Load() {
		sink = h.json
	}
	return sink.Handle(ctx, r)
}
// WithAttrs derives a handler with attrs appended to both the text and
// json sinks, sharing the same root so reconfiguration still reaches it.
// An empty attrs slice returns the receiver unchanged.
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return h
	}
	derived := &Handler{root: h.root}
	derived.text = h.text.WithAttrs(attrs)
	derived.json = h.json.WithAttrs(attrs)
	return derived
}
// WithGroup derives a handler whose subsequent attrs are nested under
// name, mirrored onto both the text and json sinks. An empty name is a
// no-op, per the slog.Handler contract.
func (h *Handler) WithGroup(name string) slog.Handler {
	if name == "" {
		return h
	}
	derived := &Handler{root: h.root}
	derived.text = h.text.WithGroup(name)
	derived.json = h.json.WithGroup(name)
	return derived
}
// SetLevel updates the effective log level. The change is visible to every
// derived logger because they all share the root's LevelVar.
func (h *Handler) SetLevel(level slog.Level) {
	h.root.level.Set(level)
}
// GetLevel reports the log level currently in effect.
func (h *Handler) GetLevel() slog.Level {
	return h.root.level.Level()
}
// SetFormat atomically switches the output format; "text" and "json" are
// the only valid values. Every derived logger observes the new format on
// its next Handle call -- nothing is rebuilt or re-registered.
func (h *Handler) SetFormat(format string) error {
	switch format {
	case "text", "json":
		h.root.jsonMode.Store(format == "json")
		return nil
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
	}
}
// GetFormat reports the name of the currently selected output format.
func (h *Handler) GetFormat() string {
	name := "text"
	if h.root.jsonMode.Load() {
		name = "json"
	}
	return name
}
// SetDisableTimestamp controls whether Handle clears r.Time before
// dispatching; slog's builtin text/json handlers omit the time attribute
// when it is the zero time.
func (h *Handler) SetDisableTimestamp(v bool) {
	h.root.disableTimestamp.Store(v)
}
// ApplyConfig reads logging.level, logging.format, and (optionally)
// logging.disable_timestamp from c and pushes them onto l's handler. Each
// capability is discovered by a structural type-assertion on l.Handler(),
// so foreign handlers silently skip whichever of SetLevel / SetFormat /
// SetDisableTimestamp they do not implement.
//
// nebula.Main does NOT call this on your behalf; callers wanting
// config-driven logging updates invoke it at startup and register it as a
// reload callback themselves, keeping the library from mutating an
// embedder's logger without their say-so.
func ApplyConfig(l *slog.Logger, c Config) error {
	handler := l.Handler()

	level, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return err
	}
	if s, ok := handler.(interface{ SetLevel(slog.Level) }); ok {
		s.SetLevel(level)
	}

	format := strings.ToLower(c.GetString("logging.format", "text"))
	if s, ok := handler.(interface{ SetFormat(string) error }); ok {
		if err := s.SetFormat(format); err != nil {
			return err
		}
	}

	if s, ok := handler.(interface{ SetDisableTimestamp(bool) }); ok {
		s.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
	}
	return nil
}
// ParseLevel converts a config-string level name ("trace", "debug", "info",
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. The logrus
// legacy names "fatal" and "panic" remain accepted for backwards
// compatibility and both collapse to slog.LevelError.
func ParseLevel(s string) (slog.Level, error) {
	switch s {
	case "trace":
		return LevelTrace, nil
	case "debug":
		return slog.LevelDebug, nil
	case "info":
		return slog.LevelInfo, nil
	case "warn", "warning":
		return slog.LevelWarn, nil
	case "error", "fatal", "panic":
		// "fatal"/"panic" are pre-slog config values; error is the closest fit.
		return slog.LevelError, nil
	}
	return 0, fmt.Errorf("not a valid logging level: %q", s)
}
// LevelName maps a slog.Level back to the human-readable name ParseLevel
// accepts for it, bucketing intermediate values down to the nearest named
// level and treating anything above warn as "error".
func LevelName(l slog.Level) string {
	if l <= LevelTrace {
		return "trace"
	}
	if l <= slog.LevelDebug {
		return "debug"
	}
	if l <= slog.LevelInfo {
		return "info"
	}
	if l <= slog.LevelWarn {
		return "warn"
	}
	return "error"
}

View File

@@ -0,0 +1,90 @@
package logging
import (
"context"
"io"
"log/slog"
"testing"
)
// BenchmarkLogger_* compare the handler returned by NewLogger against a
// stock slog handler. The number that matters is the per-log cost on a
// logger derived via .With(), since that is the shape subsystems keep on
// their structs (HostInfo.logger(), lh.l.With("subsystem", ...), etc.)
// and invoke from hot paths.
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
	logger := slog.New(slog.DiscardHandler)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Nebula_RootInfo measures an info-level log on an
// underived nebula logger writing to a discarded sink.
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
	logger := NewLogger(io.Discard)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Stock_DerivedInfo measures an info-level log on a stock
// slog logger that carries two With() attributes, mirroring how subsystems
// store derived loggers.
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
	logger := slog.New(slog.DiscardHandler).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Nebula_DerivedInfo measures an info-level log on a
// nebula logger that carries two With() attributes, the common derived
// shape on subsystem structs.
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
	logger := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// Gated-off-path benchmarks: model the hot-path idiom
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` when debug logging is
// disabled, so the gate misses. inside.go/outside.go pay exactly this cost
// on every packet, which makes it the dominant pattern to keep cheap.
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
	logger := slog.New(slog.DiscardHandler).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	ctx := context.Background()
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if logger.Enabled(ctx, slog.LevelDebug) {
			logger.Debug("hello", "i", n)
		}
	}
}
// BenchmarkLogger_Nebula_DerivedEnabledGateMiss measures the Enabled()
// gate miss on a derived nebula logger when debug logging is off -- the
// per-packet cost paid by the data path.
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
	logger := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	ctx := context.Background()
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if logger.Enabled(ctx, slog.LevelDebug) {
			logger.Debug("hello", "i", n)
		}
	}
}

83
main.go
View File

@@ -3,13 +3,13 @@ package nebula
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"runtime/debug"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/sshd"
@@ -20,7 +20,7 @@ import (
type m = map[string]any
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() {
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
buildVersion = moduleVersion()
}
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
// Print the config if in test, the exit comes later
if configTest {
b, err := yaml.Marshal(c.Settings)
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
// Print the final config
l.Println(string(b))
l.Info(string(b))
}
err := configLogger(l, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
}
c.RegisterReloadCallback(func(c *config.C) {
err := configLogger(l, c)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
}
})
pki, err := NewPKIFromConfig(l, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
}
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
sshStart = nil
}
}
@@ -99,19 +82,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
routines = 1
}
if routines > 1 {
l.WithField("routines", routines).Info("Using multiple routines")
l.Info("Using multiple routines", "routines", routines)
}
} else {
// deprecated and undocumented
tunQueues := c.GetInt("tun.routines", 1)
udpQueues := c.GetInt("listen.routines", 1)
if tunQueues > udpQueues {
routines = tunQueues
} else {
routines = udpQueues
}
routines = max(tunQueues, udpQueues)
if routines != 1 {
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
}
}
@@ -124,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
conntrackCacheTimeout = 1 * time.Second
}
if conntrackCacheTimeout > 0 {
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
}
var tun overlay.Device
@@ -170,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
@@ -219,13 +198,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
serveDns := false
if c.GetBool("lighthouse.serve_dns", false) {
if c.GetBool("lighthouse.am_lighthouse", false) {
serveDns = true
} else {
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
}
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
if err != nil {
l.Warn("Failed to start DNS responder", "error", err)
}
ifConfig := &InterfaceConfig{
@@ -234,7 +209,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Outside: udpConns[0],
pki: pki,
Firewall: fw,
ServeDns: serveDns,
DnsServer: ds,
HandshakeManager: handshakeManager,
connectionManager: connManager,
lightHouse: lightHouse,
@@ -271,7 +246,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
go handshakeManager.Run(ctx)
}
statsStart, err := startStats(l, c, buildVersion, configTest)
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
}
@@ -284,23 +259,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
attachCommands(l, c, ssh, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()
if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
}
return &Control{
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
state: StateReady,
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: stats.Start,
dnsStart: ds.Start,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
}, nil
}

View File

@@ -15,14 +15,12 @@ type endianness interface {
var noiseEndianness endianness = binary.BigEndian
type NebulaCipherState struct {
c noise.Cipher
//k [32]byte
//n uint64
c cipher.AEAD
}
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
return &NebulaCipherState{c: s.Cipher()}
x := s.Cipher()
return &NebulaCipherState{c: x.(cipher.AEAD)}
}
type cipherAEADDanger interface {
@@ -40,10 +38,6 @@ type cipherAEADDanger interface {
// be re-used by callers to minimize garbage collection.
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
switch ce := s.c.(type) {
case cipherAEADDanger:
return ce.EncryptDanger(out, ad, plaintext, n, nb)
default:
// TODO: Is this okay now that we have made messageCounter atomic?
// Alternative may be to split the counter space into ranges
//if n <= s.n {
@@ -55,10 +49,9 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
out = s.c.Seal(out, nb, plaintext, ad)
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
return out, nil
}
} else {
return nil, errors.New("no cipher state available to encrypt")
}
@@ -66,17 +59,12 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
switch ce := s.c.(type) {
case cipherAEADDanger:
return ce.DecryptDanger(out, ad, ciphertext, n, nb)
default:
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
}
return s.c.Open(out, nb, ciphertext, ad)
} else {
return []byte{}, nil
}
@@ -84,7 +72,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
func (s *NebulaCipherState) Overhead() int {
if s != nil {
return s.c.(cipher.AEAD).Overhead()
return s.c.Overhead()
}
return 0
}

View File

@@ -1,5 +1,4 @@
//go:build !boringcrypto
// +build !boringcrypto
package nebula

View File

@@ -1,15 +1,16 @@
package nebula
import (
"context"
"encoding/binary"
"errors"
"log/slog"
"net/netip"
"time"
"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/net/ipv4"
@@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
f.l.Info("Error while parsing inbound packet",
"from", via,
"error", err,
"packet", packet,
)
}
return
}
@@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
//l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Refusing to process double encrypted packet", "from", via)
}
return
}
@@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
"vpnAddrs", hostinfo.vpnAddrs,
"remoteIndex", h.RemoteIndex,
)
return
}
@@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
"relayTo", relay.PeerAddr,
"error", err,
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
)
return
}
@@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
hostinfo.logger(f.l).Info("Unexpected target relay state",
"relayTo", relay.PeerAddr,
"relayFrom", hostinfo.vpnAddrs[0],
"targetRelayState", targetRelay.State,
)
return
}
}
@@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
@@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt test packet")
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
@@ -190,9 +210,17 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if !f.handleEncrypted(ci, via, h) {
return
}
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
hostinfo.logger(f.l).WithField("from", via).
Info("Close tunnel received, tearing down.")
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
return
@@ -204,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
@@ -214,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
}
return
}
@@ -240,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
}
return
}
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
"suppressSeconds", RoamingSuppressSeconds,
"udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
}
return
}
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
"udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(via.UdpAddr)
@@ -327,13 +366,29 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
proto := layers.IPProtocol(data[protoAt])
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
case layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil
case layers.IPProtocolICMPv6:
if dataLen < offset+6 {
return ErrIPv6PacketTooShort
}
fp.Protocol = uint8(proto)
fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6
icmptype := data[offset+1]
switch icmptype {
case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply:
fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier
default:
fp.RemotePort = 0
}
fp.Fragment = false
return nil
case layers.IPProtocolTCP, layers.IPProtocolUDP:
if dataLen < offset+4 {
return ErrIPv6PacketTooShort
@@ -423,34 +478,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
minLen := ihl
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
if !fp.Fragment {
if fp.Protocol == firewall.ProtoICMP {
minLen += minFwPacketLen + 2
} else {
minLen += minFwPacketLen
}
}
if len(data) < minLen {
return ErrIPv4InvalidHeaderLength
}
// Firewall packets are locally oriented
if incoming {
if incoming { // Firewall packets are locally oriented
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
}
if fp.Fragment {
fp.RemotePort = 0
fp.LocalPort = 0
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
fp.LocalPort = 0 //code would be uint16(data[ihl+1])
} else if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
}
return nil
@@ -464,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
}
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", h).
Debugln("dropping out of window packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
}
return nil, errors.New("out of window packet")
}
@@ -477,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
return false
}
err = newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err,
"packet", out,
)
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
}
return false
}
@@ -499,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping inbound packet",
"fwPacket", fwPacket,
"reason", dropReason,
)
}
return false
}
@@ -510,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
f.l.Error("Failed to write to tun", "error", err)
}
return true
}
@@ -526,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
_ = f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
WithField("udpAddr", endpoint).
Debug("Recv error sent")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Recv error sent",
"index", index,
"udpAddr", endpoint,
)
}
}
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received, ignoring")
f.l.Debug("Recv error received, ignoring",
"index", h.RemoteIndex,
"udpAddr", addr,
)
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Recv error received",
"index", h.RemoteIndex,
"udpAddr", addr,
)
}
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if hostinfo == nil {
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
f.l.Info("Someone spoofing recv_errors?",
"addr", addr,
"hostinfoRemote", hostinfo.remote,
)
return
}

View File

@@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) {
// next layer, missing length byte
err = newPacket(buffer.Bytes()[:49], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
err = nil
// A good ICMP packet
ip = layers.IPv6{
@@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) {
DstIP: net.IPv6linklocalallnodes,
}
icmp := layers.ICMPv6{}
buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
if err != nil {
panic(err)
icmp := layers.ICMPv6{
TypeCode: layers.ICMPv6TypeEchoRequest,
Checksum: 0x1234,
}
err = newPacket(buffer.Bytes(), true, p)
require.NoError(t, err)
buffer.Clear()
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp))
require.Error(t, newPacket(buffer.Bytes(), true, p))
buffer.Clear()
echo := layers.ICMPv6Echo{
Identifier: 0xabcd,
SeqNumber: 1234,
}
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo))
require.NoError(t, newPacket(buffer.Bytes(), true, p))
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0xabcd), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
@@ -574,7 +581,7 @@ func BenchmarkParseV6(b *testing.B) {
}
evilBytes := buffer.Bytes()
for i := 0; i < 200; i++ {
for range 200 {
evilBytes = append(evilBytes, hopHeader...)
}
evilBytes = append(evilBytes, lastHopHeader...)

View File

@@ -1,4 +1,6 @@
package test
// Package overlaytest provides fakes of overlay.Device for tests that do
// not want to touch a real tun device or route table.
package overlaytest
import (
"errors"
@@ -8,6 +10,9 @@ import (
"github.com/slackhq/nebula/routing"
)
// NoopTun is an overlay.Device that silently discards every read and write.
// Useful in tests that need to construct a nebula Interface but do not
// exercise the datapath.
type NoopTun struct{}
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {

View File

@@ -2,6 +2,7 @@ package overlay
import (
"fmt"
"log/slog"
"math"
"net"
"net/netip"
@@ -9,7 +10,6 @@ import (
"strconv"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
@@ -48,11 +48,14 @@ func (r Route) String() string {
return s
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
routeTree := new(bart.Table[routing.Gateways])
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
l.Warn("route MTU is not supported on this platform",
"goos", runtime.GOOS,
"route", r,
)
}
gateways := r.Via

View File

@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true)
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("1.0.0.2")
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 3)
routeTree, err := makeRouteTree(l, routes, true)
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("192.168.86.1")

View File

@@ -2,10 +2,10 @@ package overlay
import (
"fmt"
"log/slog"
"net"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
)
@@ -22,9 +22,9 @@ func (e *NameError) Error() string {
}
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
}
}

View File

@@ -6,12 +6,12 @@ package overlay
import (
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -23,10 +23,10 @@ type tun struct {
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync/atomic"
@@ -14,7 +15,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -30,7 +30,7 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
l *slog.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error {
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr)
} else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
@@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error {
}
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -1,13 +1,14 @@
package overlay
import (
"context"
"fmt"
"io"
"log/slog"
"net/netip"
"strings"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/routing"
)
@@ -19,10 +20,10 @@ type disabledTun struct {
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
rx metrics.Counter
l *logrus.Logger
l *slog.Logger
}
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
tun := &disabledTun{
vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen),
@@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
}
t.tx.Inc(1)
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Write payload", "raw", prettyPacket(r))
}
return copy(b, r), nil
@@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
select {
case t.read <- out:
default:
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
t.l.Debug("tun_disabled: dropped ICMP Echo Reply response")
}
return true
@@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
// Check for ICMP Echo Request before spending time doing the full parsing
if t.handleICMPEchoRequest(b) {
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b))
}
} else if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
} else if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b))
}
return len(b), nil
}

View File

@@ -0,0 +1,120 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay
import (
"errors"
"os"
"sync"
"testing"
"time"
"golang.org/x/sys/unix"
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
t.Fatalf("pipe2: %v", err)
}
t.Cleanup(func() { _ = unix.Close(fds[1]) })
return fds[0]
}
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
t.Cleanup(func() { _ = tf.Close() })
done := make(chan error, 1)
go func() {
_, err := tf.Read(make([]byte, 64))
done <- err
}()
// Verify Read is actually blocked in poll.
select {
case err := <-done:
t.Fatalf("Read returned before shutdown signal: %v", err)
case <-time.After(50 * time.Millisecond):
}
if err := tf.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
select {
case err := <-done:
if !errors.Is(err, os.ErrClosed) {
t.Fatalf("expected os.ErrClosed, got %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Read did not wake on shutdown")
}
}
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
parent, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
friend, err := parent.newFriend(newReadPipe(t))
if err != nil {
_ = parent.Close()
t.Fatalf("newFriend: %v", err)
}
t.Cleanup(func() {
_ = friend.Close()
_ = parent.Close()
})
readers := []*tunFile{parent, friend}
errs := make([]error, len(readers))
var wg sync.WaitGroup
for i, r := range readers {
wg.Add(1)
go func(i int, r *tunFile) {
defer wg.Done()
_, errs[i] = r.Read(make([]byte, 64))
}(i, r)
}
time.Sleep(50 * time.Millisecond)
if err := parent.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("readers did not wake")
}
for i, err := range errs {
if !errors.Is(err, os.ErrClosed) {
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
}
}
}
func TestTunFile_Close_Idempotent(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("first Close: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("second Close should be a no-op, got %v", err)
}
}

View File

@@ -9,15 +9,18 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/netip"
"os"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
@@ -92,87 +95,165 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
l *slog.Logger
fd int
shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls
shutdownW int // write end of the shutdown pipe; closing this signals shutdown to any blocked reader/writer
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed atomic.Bool
}
// blockOnRead waits until the tun fd is readable or shutdown has been signaled.
// Returns os.ErrClosed if Close was called.
func (t *tun) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(t.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
tunEvents := t.readPoll[0].Revents
shutdownEvents := t.readPoll[1].Revents
t.readPoll[0].Revents = 0
t.readPoll[1].Revents = 0
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
}
if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
// blockOnWrite waits until the tun fd is writable or shutdown has been
// signaled. Returns os.ErrClosed if Close was called (or the tun fd is
// dead), nil when the fd is ready, or the raw poll error otherwise.
func (t *tun) blockOnWrite() error {
	// Events that mean the fd is unusable, not merely unready.
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var err error
	for {
		// Infinite timeout (-1); EINTR just means a signal fired, so retry.
		_, err = unix.Poll(t.writePoll[:], -1)
		if err != unix.EINTR {
			break
		}
	}
	// Snapshot the revents, then always clear them so stale bits are never
	// seen by the next poll.
	tunEvents := t.writePoll[0].Revents
	shutdownEvents := t.writePoll[1].Revents
	t.writePoll[0].Revents = 0
	t.writePoll[1].Revents = 0
	// Check the poll error before trusting the potentially bogus revents.
	if err != nil {
		return err
	}
	// Any activity on the shutdown pipe (including hangup from Close) means
	// we are shutting down.
	if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
		return os.ErrClosed
	}
	if tunEvents&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
var head [4]byte
iovecs := [2]syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
for {
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
if errno == 0 {
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
if bytesRead < 4 {
return 0, nil
}
return bytesRead - 4, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnRead(); err != nil {
return 0, err
} else {
return bytesRead - 4, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
var head [4]byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
switch ipVer {
case 4:
head[3] = syscall.AF_INET
case 6:
head[3] = syscall.AF_INET6
default:
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
iovecs := [2]syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
for {
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
if errno == 0 {
return int(n) - 4, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnWrite(); err != nil {
return 0, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
return int(n) - 4, err
}
func (t *tun) Close() error {
if t.devFd >= 0 {
err := syscall.Close(t.devFd)
if err != nil {
t.l.WithError(err).Error("Error closing device")
if t.closed.Swap(true) {
return nil
}
// Closing the write end of the shutdown pipe causes any blocked Poll to
// return with POLLHUP on the shutdown fd, so readers/writers wake up and
// exit with os.ErrClosed.
if t.shutdownW >= 0 {
_ = unix.Close(t.shutdownW)
t.shutdownW = -1
}
if t.fd >= 0 {
if err := unix.Close(t.fd); err != nil {
t.l.Error("Error closing device", "error", err)
}
t.fd = -1
}
if t.shutdownR >= 0 {
_ = unix.Close(t.shutdownR)
t.shutdownR = -1
}
t.devFd = -1
c := make(chan struct{})
go func() {
@@ -185,7 +266,7 @@ func (t *tun) Close() error {
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
t.l.Error("Error destroying tunnel", "error", err)
}
}()
@@ -194,31 +275,52 @@ func (t *tun) Close() error {
case <-c:
case <-time.After(1 * time.Second):
}
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
// newTunFromFd is unsupported on this platform; FreeBSD tun devices must be
// opened by name via newTun.
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
fd, err = unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
fd, err = unix.Open("/dev/tun", os.O_RDWR, 0)
}
if err != nil {
return nil, err
}
if err = unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("failed to set tun device as nonblocking: %w", err)
}
// Shutdown pipe lets Close wake any reader/writer blocked in Poll.
var pipeFds [2]int
if err = unix.Pipe2(pipeFds[:], unix.O_CLOEXEC|unix.O_NONBLOCK); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("failed to create shutdown pipe: %w", err)
}
shutdownR, shutdownW := pipeFds[0], pipeFds[1]
closeOnErr := true
defer func() {
if closeOnErr {
_ = unix.Close(fd)
_ = unix.Close(shutdownR)
_ = unix.Close(shutdownW)
}
}()
// Read the name of the interface
var name [16]byte
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
@@ -237,7 +339,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
if ctrlErr != nil {
return nil, err
return nil, ctrlErr
}
ifName := string(bytes.TrimRight(name[:], "\x00"))
@@ -253,8 +355,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
defer syscall.Close(s)
fd := uintptr(s)
var fromName [16]byte
var toName [16]byte
copy(fromName[:], ifName)
@@ -266,7 +366,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
// Set the device name
_ = ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
_ = ioctl(uintptr(s), syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
}
t := &tun{
@@ -274,13 +374,24 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
fd: fd,
shutdownR: shutdownR,
shutdownW: shutdownW,
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownR), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownR), Events: unix.POLLIN},
},
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
closeOnErr = false
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
@@ -475,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -490,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync"
@@ -14,7 +15,6 @@ import (
"syscall"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -25,14 +25,14 @@ type tun struct {
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
vpnNetworks: vpnNetworks,

View File

@@ -4,8 +4,10 @@
package overlay
import (
"encoding/binary"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"os"
@@ -16,7 +18,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -24,9 +25,175 @@ import (
"golang.org/x/sys/unix"
)
type tun struct {
io.ReadWriteCloser
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct {
	// fd is this queue's tun descriptor, always in non-blocking mode.
	fd int
	// shutdownFd is the eventfd shared by every queue of this device;
	// writing to it (wakeForShutdown) wakes any blocked poll.
	shutdownFd int
	// lastOne is true only on the tunFile that created the eventfd
	// (newTunFd); its Close also closes shutdownFd.
	lastOne bool
	// readPoll/writePoll are pre-built poll sets: slot 0 is the tun fd
	// (POLLIN resp. POLLOUT), slot 1 is the shutdown eventfd (POLLIN).
	readPoll  [2]unix.PollFd
	writePoll [2]unix.PollFd
	// closed guards against double-close. NOTE(review): plain bool, so this
	// assumes callers serialize Close (tun.Close holds closeLock) — confirm.
	closed bool
}
// newFriend makes a tunFile for a MultiQueueReader, reusing the shutdown
// eventfd of the parent so a single wakeForShutdown call wakes every queue.
// The provided fd is switched to non-blocking mode.
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
	if err := unix.SetNonblock(fd, true); err != nil {
		return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
	}
	// lastOne stays false: only the original tunFile owns the eventfd.
	child := &tunFile{
		fd:         fd,
		shutdownFd: r.shutdownFd,
	}
	shutdownPoll := unix.PollFd{Fd: int32(r.shutdownFd), Events: unix.POLLIN}
	child.readPoll = [2]unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}, shutdownPoll}
	child.writePoll = [2]unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}, shutdownPoll}
	return child, nil
}
// newTunFd wraps an already-opened tun fd in a tunFile, switching it to
// non-blocking mode and creating the shared shutdown eventfd that lets Close
// wake any blocked poller. The returned tunFile owns the eventfd
// (lastOne is true), so closing it also closes the eventfd.
func newTunFd(fd int) (*tunFile, error) {
	if err := unix.SetNonblock(fd, true); err != nil {
		return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
	}
	shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
	if err != nil {
		return nil, fmt.Errorf("failed to create eventfd: %w", err)
	}
	// Slot 0 watches the tun fd, slot 1 watches the shutdown eventfd.
	shutdownPoll := unix.PollFd{Fd: int32(shutdownFd), Events: unix.POLLIN}
	return &tunFile{
		fd:         fd,
		shutdownFd: shutdownFd,
		lastOne:    true,
		readPoll:   [2]unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}, shutdownPoll},
		writePoll:  [2]unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}, shutdownPoll},
	}, nil
}
// blockOnRead parks the caller in poll(2) until the tun fd becomes readable
// or the shared shutdown eventfd fires. It returns os.ErrClosed on shutdown
// (or when the tun fd reports an error condition), nil when data is ready,
// and the raw poll error otherwise.
func (r *tunFile) blockOnRead() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR

	// Poll with no timeout, retrying if a signal interrupts the call.
	var pollErr error
	for {
		if _, pollErr = unix.Poll(r.readPoll[:], -1); pollErr != unix.EINTR {
			break
		}
	}

	// Snapshot the returned events, then clear them so stale bits never
	// leak into the next poll.
	devEvents := r.readPoll[0].Revents
	stopEvents := r.readPoll[1].Revents
	r.readPoll[0].Revents, r.readPoll[1].Revents = 0, 0

	// A failed poll leaves the revents undefined, so bail out first.
	if pollErr != nil {
		return pollErr
	}
	// A shutdown signal or a broken tun fd both surface as "closed".
	if stopEvents&(unix.POLLIN|problemFlags) != 0 || devEvents&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// blockOnWrite waits in poll(2) until the tun fd is writable or the shared
// shutdown eventfd fires. Returns os.ErrClosed on shutdown or a dead fd,
// nil when the fd is ready, and the raw poll error otherwise.
func (r *tunFile) blockOnWrite() error {
	// Events that mean the fd is unusable, not merely unready.
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var err error
	for {
		// Infinite timeout; EINTR just means a signal fired, so retry.
		_, err = unix.Poll(r.writePoll[:], -1)
		if err != unix.EINTR {
			break
		}
	}
	// Always reset Revents so stale bits never leak into the next poll.
	tunEvents := r.writePoll[0].Revents
	shutdownEvents := r.writePoll[1].Revents
	r.writePoll[0].Revents = 0
	r.writePoll[1].Revents = 0
	// Do the err check before trusting the potentially bogus bits we just got.
	if err != nil {
		return err
	}
	if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
		return os.ErrClosed
	} else if tunEvents&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// Read pulls one packet from the tun fd into buf, parking in poll whenever
// no data is available. It returns os.ErrClosed once the tunFile has been
// shut down (via the eventfd or a closed fd).
func (r *tunFile) Read(buf []byte) (int, error) {
	for {
		n, err := unix.Read(r.fd, buf)
		switch err {
		case nil:
			return n, nil
		case unix.EAGAIN:
			// Nothing buffered; wait for readability or shutdown.
			if werr := r.blockOnRead(); werr != nil {
				return 0, werr
			}
		case unix.EINTR:
			// Interrupted by a signal; retry the read.
		case unix.EBADF:
			// The fd was closed out from under us.
			return 0, os.ErrClosed
		default:
			return 0, err
		}
	}
}
// Write pushes buf to the tun fd, parking in poll whenever the device is not
// ready to accept more data. It returns os.ErrClosed once the tunFile has
// been shut down.
func (r *tunFile) Write(buf []byte) (int, error) {
	for {
		n, err := unix.Write(r.fd, buf)
		switch err {
		case nil:
			return n, nil
		case unix.EAGAIN:
			// Device not ready; wait for writability or shutdown.
			if werr := r.blockOnWrite(); werr != nil {
				return 0, werr
			}
		case unix.EINTR:
			// Interrupted by a signal; retry the write.
		case unix.EBADF:
			// The fd was closed out from under us.
			return 0, os.ErrClosed
		default:
			return 0, err
		}
	}
}
// wakeForShutdown increments the shared shutdown eventfd, making it readable
// so that every poller blocked in blockOnRead/blockOnWrite (this queue and
// all multiqueue friends sharing the eventfd) wakes and returns os.ErrClosed.
func (r *tunFile) wakeForShutdown() error {
	payload := make([]byte, 8)
	binary.NativeEndian.PutUint64(payload, 1)
	_, err := unix.Write(int(r.readPoll[1].Fd), payload)
	return err
}
// Close tears down this tunFile. Only the original tunFile (lastOne) also
// closes the shared shutdown eventfd; multiqueue friends close just their own
// queue fd. Subsequent calls are no-ops and return nil.
// NOTE(review): the closed flag is a plain bool, so idempotence assumes
// callers serialize Close (tun.Close holds closeLock while closing readers) —
// confirm no unsynchronized caller exists.
func (r *tunFile) Close() error {
	if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
		return nil
	}
	r.closed = true
	if r.lastOne {
		_ = unix.Close(r.shutdownFd)
	}
	return unix.Close(r.fd)
}
type tun struct {
*tunFile
readers []*tunFile
closeLock sync.Mutex
Device string
vpnNetworks []netip.Prefix
MaxMTU int
@@ -46,7 +213,7 @@ type tun struct {
routesFromSystem map[netip.Prefix]routing.Gateways
routesFromSystemLock sync.Mutex
l *logrus.Logger
l *slog.Logger
}
func (t *tun) Networks() []netip.Prefix {
@@ -71,10 +238,8 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
if err != nil {
return nil, err
}
@@ -84,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -115,6 +280,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
nameStr := c.GetString("tun.dev", "")
copy(req.Name[:], nameStr)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
_ = unix.Close(fd)
return nil, &NameError{
Name: nameStr,
Underlying: err,
@@ -122,8 +288,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, fd, vpnNetworks)
if err != nil {
return nil, err
}
@@ -133,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
tfd, err := newTunFd(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
}
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
tunFile: tfd,
readers: []*tunFile{tfd},
closeLock: sync.Mutex{},
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
@@ -145,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
l: l,
}
err := t.reload(c, true)
if err != nil {
if err = t.reload(c, true); err != nil {
_ = t.Close()
return nil, err
}
@@ -206,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error {
if !initial {
if oldMaxMTU != newMaxMTU {
t.setMTU()
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
t.l.Warn(err.Error())
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU)
}
}
}
@@ -239,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool {
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
t.closeLock.Lock()
defer t.closeLock.Unlock()
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@@ -248,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
_ = unix.Close(fd)
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
out, err := t.tunFile.newFriend(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
}
return file, nil
t.readers = append(t.readers, out)
return out, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -261,29 +443,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r
}
func (t *tun) Write(b []byte) (int, error) {
var nn int
maximum := len(b)
for {
n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 {
nn += n
}
if nn == len(b) {
return nn, err
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, io.ErrUnexpectedEOF
}
}
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
@@ -333,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error {
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
t.l.Error("failed to remove address from tun address list", "error", err)
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
t.l.Info("removed address not listed in cert(s)", "removed", al[i].String())
}
}
@@ -379,12 +538,12 @@ func (t *tun) Activate() error {
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length")
t.l.Error("Failed to set tun tx queue length", "error", err)
}
const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation")
t.l.Warn("Failed to disable link local address generation", "error", err)
}
if err = t.addIPs(link); err != nil {
@@ -423,7 +582,7 @@ func (t *tun) setMTU() {
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu")
t.l.Error("Failed to set tun mtu", "error", err)
}
}
@@ -446,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
}
err := netlink.RouteReplace(&nr)
if err != nil {
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
for i := 0; i < 2; i++ {
time.Sleep(100 * time.Millisecond)
@@ -454,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
if err == nil {
break
} else {
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
t.l.Warn("Failed to set default route MTU, retrying",
"error", err,
"cidr", cidr,
"mtu", t.DefaultMTU,
)
}
}
if err != nil {
@@ -499,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -531,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) {
err := netlink.RouteDel(&nr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
}
@@ -562,11 +725,11 @@ func (t *tun) watchRoutes() {
netlinkOptions := netlink.RouteSubscribeOptions{
ReceiveBufferSize: t.useSystemRoutesBufferSize,
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) },
}
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
t.l.Error("failed to subscribe to system route changes", "error", err)
return
}
@@ -608,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device)
return gateways
}
@@ -620,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
}
}
@@ -636,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
}
}
}
@@ -671,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
t.l.Debug("Ignoring route update, no gateways", "route", r)
return
}
if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
t.l.Debug("Ignoring route update, no destination address", "route", r)
return
}
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
t.l.Debug("Ignoring route update, invalid destination address", "route", r)
return
}
@@ -693,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
t.routesFromSystemLock.Lock()
if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
t.l.Info("Adding route", "destination", dst, "via", gateways)
t.routesFromSystem[dst] = gateways
newTree.Insert(dst, gateways)
} else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
t.l.Info("Removing route", "destination", dst, "via", gateways)
delete(t.routesFromSystem, dst)
newTree.Delete(dst)
}
@@ -707,18 +870,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
}
func (t *tun) Close() error {
t.closeLock.Lock()
defer t.closeLock.Unlock()
if t.routeChan != nil {
close(t.routeChan)
t.routeChan = nil
}
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}
// Signal all readers blocked in poll to wake up and exit
_ = t.tunFile.wakeForShutdown()
if t.ioctlFd > 0 {
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
_ = unix.Close(int(t.ioctlFd))
t.ioctlFd = 0
}
return nil
for i := range t.readers {
if i == 0 {
continue //we want to close the zeroth reader last
}
err := t.readers[i].Close()
if err != nil {
t.l.Error("error closing tun reader", "reader", i, "error", err)
} else {
t.l.Info("closed tun reader", "reader", i)
}
}
//this is t.readers[0] too
err := t.tunFile.Close()
if err != nil {
t.l.Error("error closing tun reader", "reader", 0, "error", err)
} else {
t.l.Info("closed tun reader", "reader", 0)
}
return err
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"regexp"
@@ -15,7 +16,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -63,18 +63,18 @@ type tun struct {
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
f *os.File
fd int
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
@@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
l.Warn("Failed to set the tun device as nonblocking", "error", err)
}
t := &tun{
@@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"regexp"
@@ -15,7 +16,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -54,7 +54,7 @@ type tun struct {
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
f *os.File
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -63,11 +63,11 @@ type tun struct {
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
@@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
l.Warn("Failed to set the tun device as nonblocking", "error", err)
}
t := &tun{
@@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -4,14 +4,15 @@
package overlay
import (
"context"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
@@ -21,14 +22,14 @@ type TestTun struct {
vpnNetworks []netip.Prefix
Routes []Route
routeTree *bart.Table[routing.Gateways]
l *logrus.Logger
l *slog.Logger
closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil {
return nil, err
@@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}, nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
@@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) {
return
}
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
}
t.rxPackets <- packet
}

View File

@@ -7,6 +7,7 @@ import (
"crypto"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"path/filepath"
@@ -16,7 +17,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -33,16 +33,16 @@ type winTun struct {
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists()
if err != nil {
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
@@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
l.Debug("Failed to create wintun device, retrying", "error", err)
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, &NameError{
@@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
if !foundDefault4 {
@@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error {
// See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -2,14 +2,14 @@ package overlay
import (
"io"
"log/slog"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(vpnNetworks)
}

33
pki.go
View File

@@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"os"
@@ -14,7 +16,6 @@ import (
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
@@ -23,7 +24,7 @@ import (
type PKI struct {
cs atomic.Pointer[CertState]
caPool atomic.Pointer[cert.CAPool]
l *logrus.Logger
l *slog.Logger
}
type CertState struct {
@@ -45,7 +46,7 @@ type CertState struct {
myVpnBroadcastAddrsTable *bart.Lite
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) {
pki := &PKI{l: l}
err := pki.reload(c, true)
if err != nil {
@@ -181,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState)
if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
p.l.Debug("Client nebula certificate(s)", "cert", newState)
} else {
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
p.l.Info("Client certificate(s) refreshed from disk", "cert", newState)
}
return nil
}
@@ -195,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
}
p.caPool.Store(caPool)
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints())
return nil
}
@@ -486,32 +487,32 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
return c, b, nil
}
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
var rawCA []byte
var err error
func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) {
caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided")
}
if strings.Contains(caPathOrPEM, "-----BEGIN") {
rawCA = []byte(caPathOrPEM)
var caReader io.ReadCloser
var err error
if strings.Contains(caPathOrPEM, "-----BEGIN") {
caReader = io.NopCloser(strings.NewReader(caPathOrPEM))
} else {
rawCA, err = os.ReadFile(caPathOrPEM)
caReader, err = os.Open(caPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
}
}
defer caReader.Close()
caPool, err := cert.NewCAPoolFromPEM(rawCA)
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
if errors.Is(err, cert.ErrExpired) {
var expired int
for _, crt := range caPool.CAs {
if crt.Certificate.Expired(time.Now()) {
expired++
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
l.Warn("expired certificate present in CA pool", "cert", crt)
}
}
@@ -529,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
caPool.BlocklistFingerprint(fp)
}
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
l.Info("Blocklisted certificates", "fingerprintCount", len(bl))
}
return caPool, nil

121
pki_hup_benchmark_test.go Normal file
View File

@@ -0,0 +1,121 @@
package nebula
import (
"bytes"
"fmt"
"net/netip"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/slackhq/nebula/cert"
cert_test "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/require"
)
// BenchmarkReloadConfigWithCAs measures how long a config reload takes as the
// number of CA certificates in the pki.ca bundle grows. Reload cost matters
// because operators send HUP/reload to running nodes; each sub-benchmark
// builds a bundle of `size` CAs on disk and times repeated reloads.
func BenchmarkReloadConfigWithCAs(b *testing.B) {
	// Pin to a single P so results are comparable across runs and machines.
	prevProcs := runtime.GOMAXPROCS(1)
	b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) })

	for _, size := range []int{100, 250, 500, 1000, 5000} {
		b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) {
			l := test.NewLogger()
			dir := b.TempDir()

			ca, caKey, caBundle := buildCABundle(b, size)
			caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle)

			configBody := fmt.Sprintf(`pki:
  ca: %s
  cert: %s
  key: %s
`, caPath, certPath, keyPath)
			configPath := filepath.Join(dir, "config.yml")
			require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600))

			c := config.NewC(l)
			require.NoError(b, c.Load(dir))

			// Prime the PKI once (registers the reload callback) before the
			// timed loop so only ReloadConfig itself is measured. Reuse `l`
			// rather than allocating a second logger.
			_, err := NewPKIFromConfig(l, c)
			require.NoError(b, err)

			b.ReportAllocs()
			b.ResetTimer()
			for b.Loop() {
				c.ReloadConfig()
			}
		})
	}
}
// buildCABundle returns a primary test CA certificate, its private-key PEM,
// and a PEM bundle holding that CA plus count-1 additional throwaway CAs.
// Comment lines are interleaved between the PEM blocks to exercise the
// parser's handling of non-PEM content in the bundle.
func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) {
	b.Helper()
	require.GreaterOrEqual(b, count, 1)

	notBefore := time.Now().Add(-24 * time.Hour)
	notAfter := time.Now().Add(24 * time.Hour)
	primary, _, primaryKey, primaryPEM := cert_test.NewTestCaCert(
		cert.Version2,
		cert.Curve_CURVE25519,
		notBefore,
		notAfter,
		nil,
		nil,
		nil,
	)

	var bundle bytes.Buffer
	bundle.Write(primaryPEM)
	bundle.WriteString("\n# a comment!\n")

	for i := 1; i < count; i++ {
		_, _, _, extraPEM := cert_test.NewTestCaCert(
			cert.Version2,
			cert.Curve_CURVE25519,
			time.Now(),
			time.Now().Add(time.Hour),
			nil,
			nil,
			nil,
		)
		bundle.WriteString("\n# a comment!\n")
		bundle.Write(extraPEM)
	}

	return primary, primaryKey, bundle.Bytes()
}
// writePKIFiles writes the CA bundle, a leaf certificate signed by ca, and the
// leaf's private key into dir, returning the three paths (ca, cert, key).
func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) {
	b.Helper()

	networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}
	// Only the key/cert PEMs are needed here; the parsed forms are discarded.
	_, _, keyPEM, certPEM := cert_test.NewTestCert(
		cert.Version2,
		cert.Curve_CURVE25519,
		ca,
		caKey,
		"reload-benchmark",
		time.Now(),
		time.Now().Add(time.Hour),
		networks,
		nil,
		nil,
	)

	caPath := filepath.Join(dir, "ca.pem")
	certPath := filepath.Join(dir, "cert.pem")
	keyPath := filepath.Join(dir, "key.pem")
	for path, data := range map[string][]byte{
		caPath:   caBundle,
		certPath: certPEM,
		keyPath:  keyPEM,
	} {
		require.NoError(b, os.WriteFile(path, data, 0o600))
	}

	return caPath, certPath, keyPath
}

View File

@@ -1,10 +1,10 @@
package nebula
import (
"log/slog"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
@@ -14,10 +14,10 @@ type Punchy struct {
delay atomic.Int64
respondDelay atomic.Int64
punchEverything atomic.Bool
l *logrus.Logger
l *slog.Logger
}
func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
p := &Punchy{l: l}
p.reload(c, true)
@@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
p.respond.Store(yes)
if !initial {
p.l.Infof("punchy.respond changed to %v", p.GetRespond())
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
}
}
@@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) {
if initial || c.HasChanged("punchy.delay") {
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
if !initial {
p.l.Infof("punchy.delay changed to %s", p.GetDelay())
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
}
}
if initial || c.HasChanged("punchy.target_all_remotes") {
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
if !initial {
p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
}
}
if initial || c.HasChanged("punchy.respond_delay") {
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
if !initial {
p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay())
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
}
}
}

View File

@@ -1,6 +1,8 @@
package nebula
import (
"context"
"log/slog"
"testing"
"time"
@@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
c := config.NewC(l)
// Test defaults
p := NewPunchyFromConfig(l, c)
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.False(t, p.GetPunch())
assert.False(t, p.GetRespond())
assert.Equal(t, time.Second, p.GetDelay())
@@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) {
// punchy deprecation
c.Settings["punchy"] = true
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
// punchy.punch
c.Settings["punchy"] = map[string]any{"punch": true}
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
// punch_back deprecation
c.Settings["punch_back"] = true
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
// punchy.respond
c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punch_back"] = false
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
// punchy.delay
c.Settings["punchy"] = map[string]any{"delay": "1m"}
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetDelay())
// punchy.respond_delay
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetRespondDelay())
}
@@ -62,7 +64,7 @@ punchy:
delay: 1m
respond: false
`))
p := NewPunchyFromConfig(l, c)
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, delay, p.GetDelay())
assert.False(t, p.GetRespond())
@@ -76,3 +78,158 @@ punchy:
assert.Equal(t, newDelay, p.GetDelay())
assert.True(t, p.GetRespond())
}
// The tests below pin the shape of each log line Punchy produces so changes
// cannot silently break whatever operators are grepping for. The assertions
// are on the structured message + attrs (e.g. "punchy.respond changed" with
// a respond=true field) rather than a formatted string.
//
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
// not supported" warning whenever any key under punchy changes, because of
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
// punchy form. The tests filter by message rather than asserting total
// entry counts so that warning is tolerated without being locked into
// the format.
type capturedEntry struct {
Level slog.Level
Msg string
Attrs map[string]any
}
// capturingHandler is a slog.Handler that records each Record it receives so
// tests can assert on the level, message, and attribute map of individual log
// lines without coupling to any specific text format.
type capturingHandler struct {
entries []capturedEntry
}
func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error {
e := capturedEntry{
Level: r.Level,
Msg: r.Message,
Attrs: make(map[string]any),
}
r.Attrs(func(a slog.Attr) bool {
e.Attrs[a.Key] = a.Value.Resolve().Any()
return true
})
h.entries = append(h.entries, e)
return nil
}
func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h }
// newCapturingPunchyLogger returns a slog.Logger wired to a fresh
// capturingHandler, plus the handler itself so the test can inspect the
// entries it records.
func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) {
	t.Helper()
	h := &capturingHandler{}
	logger := slog.New(h)
	return logger, h
}
// findEntry returns the first captured entry whose message equals msg,
// failing the test immediately when no entry matches.
func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry {
	t.Helper()
	for i := range entries {
		if entries[i].Msg == msg {
			return entries[i]
		}
	}
	t.Fatalf("no entry with message %q among %d entries", msg, len(entries))
	return capturedEntry{}
}
// TestPunchy_LogFormat_InitialEnabled pins the startup log line when punching
// is enabled: info-level "punchy enabled" with no attributes.
func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
	logger, hook := newCapturingPunchyLogger(t)
	cfg := config.NewC(test.NewLogger())
	require.NoError(t, cfg.LoadString(`punchy: {punch: true}`))

	NewPunchyFromConfig(logger, cfg)

	got := findEntry(t, hook.entries, "punchy enabled")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Empty(t, got.Attrs)
}
// TestPunchy_LogFormat_InitialDisabled pins the startup log line when punching
// is disabled: info-level "punchy disabled" with no attributes.
func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
	logger, hook := newCapturingPunchyLogger(t)
	cfg := config.NewC(test.NewLogger())
	require.NoError(t, cfg.LoadString(`punchy: {punch: false}`))

	NewPunchyFromConfig(logger, cfg)

	got := findEntry(t, hook.entries, "punchy disabled")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Empty(t, got.Attrs)
}
// TestPunchy_LogFormat_ReloadPunchUnsupported pins the warning emitted when a
// reload tries to flip punchy.punch, which is not reloadable.
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
	logger, hook := newCapturingPunchyLogger(t)
	cfg := config.NewC(test.NewLogger())
	require.NoError(t, cfg.LoadString(`punchy: {punch: false}`))
	NewPunchyFromConfig(logger, cfg)

	// Drop the startup entries so only the reload's output is inspected.
	hook.entries = nil
	require.NoError(t, cfg.ReloadConfigString(`punchy: {punch: true}`))

	got := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
	assert.Equal(t, slog.LevelWarn, got.Level)
	assert.Empty(t, got.Attrs)
}
// TestPunchy_LogFormat_ReloadRespond pins the log line produced when a reload
// changes punchy.respond: info-level with a single respond attribute.
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
	logger, hook := newCapturingPunchyLogger(t)
	cfg := config.NewC(test.NewLogger())
	require.NoError(t, cfg.LoadString(`punchy: {respond: false}`))
	NewPunchyFromConfig(logger, cfg)

	// Drop the startup entries so only the reload's output is inspected.
	hook.entries = nil
	require.NoError(t, cfg.ReloadConfigString(`punchy: {respond: true}`))

	got := findEntry(t, hook.entries, "punchy.respond changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"respond": true}, got.Attrs)
}
// TestPunchy_LogFormat_ReloadDelay pins the log line produced when a reload
// changes punchy.delay: info-level with the new delay as a duration attribute.
func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
	logger, hook := newCapturingPunchyLogger(t)
	cfg := config.NewC(test.NewLogger())
	require.NoError(t, cfg.LoadString(`punchy: {delay: 1s}`))
	NewPunchyFromConfig(logger, cfg)

	// Drop the startup entries so only the reload's output is inspected.
	hook.entries = nil
	require.NoError(t, cfg.ReloadConfigString(`punchy: {delay: 10s}`))

	got := findEntry(t, hook.entries, "punchy.delay changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"delay": 10 * time.Second}, got.Attrs)
}
// TestPunchy_LogFormat_ReloadTargetAllRemotes pins the log line produced when
// a reload changes punchy.target_all_remotes.
func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
	logger, hook := newCapturingPunchyLogger(t)
	cfg := config.NewC(test.NewLogger())
	require.NoError(t, cfg.LoadString(`punchy: {target_all_remotes: false}`))
	NewPunchyFromConfig(logger, cfg)

	// Drop the startup entries so only the reload's output is inspected.
	hook.entries = nil
	require.NoError(t, cfg.ReloadConfigString(`punchy: {target_all_remotes: true}`))

	got := findEntry(t, hook.entries, "punchy.target_all_remotes changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"target_all_remotes": true}, got.Attrs)
}
// TestPunchy_LogFormat_ReloadRespondDelay pins the log line produced when a
// reload changes punchy.respond_delay.
func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
	logger, hook := newCapturingPunchyLogger(t)
	cfg := config.NewC(test.NewLogger())
	require.NoError(t, cfg.LoadString(`punchy: {respond_delay: 5s}`))
	NewPunchyFromConfig(logger, cfg)

	// Drop the startup entries so only the reload's output is inspected.
	hook.entries = nil
	require.NoError(t, cfg.ReloadConfigString(`punchy: {respond_delay: 15s}`))

	got := findEntry(t, hook.entries, "punchy.respond_delay changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, got.Attrs)
}

View File

@@ -5,22 +5,22 @@ import (
"encoding/binary"
"errors"
"fmt"
"log/slog"
"net/netip"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
)
type relayManager struct {
l *logrus.Logger
l *slog.Logger
hostmap *HostMap
amRelay atomic.Bool
}
func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager {
func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
rm := &relayManager{
l: l,
hostmap: hostmap,
@@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c
c.RegisterReloadCallback(func(c *config.C) {
err := rm.reload(c, false)
if err != nil {
l.WithError(err).Error("Failed to reload relay_manager")
rm.l.Error("Failed to reload relay_manager", "error", err)
}
})
return rm
@@ -52,10 +52,10 @@ func (rm *relayManager) setAmRelay(v bool) {
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
hm.Lock()
defer hm.Unlock()
for i := 0; i < 32; i++ {
for range 32 {
index, err := generateIndex(l)
if err != nil {
return 0, err
@@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
if !ok {
fields := logrus.Fields{
"relay": relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex": m.InitiatorRelayIndex,
}
var relayFrom, relayTo any
if m.RelayFromAddr == nil {
fields["relayFrom"] = m.OldRelayFromAddr
relayFrom = m.OldRelayFromAddr
} else {
fields["relayFrom"] = m.RelayFromAddr
relayFrom = m.RelayFromAddr
}
if m.RelayToAddr == nil {
fields["relayTo"] = m.OldRelayToAddr
relayTo = m.OldRelayToAddr
} else {
fields["relayTo"] = m.RelayToAddr
relayTo = m.RelayToAddr
}
rm.l.WithFields(fields).Info("relayManager failed to update relay")
rm.l.Info("relayManager failed to update relay",
"relay", relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex", m.InitiatorRelayIndex,
"relayFrom", relayFrom,
"relayTo", relayTo,
)
return nil, fmt.Errorf("unknown relay")
}
@@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
msg := &NebulaControl{}
err := msg.Unmarshal(d)
if err != nil {
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
h.logger(f.l).Error("Failed to unmarshal control message", "error", err)
return
}
@@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
}
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{
"relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
"relayTo": protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}).
Info("handleCreateRelayResponse")
rm.l.Info("handleCreateRelayResponse",
"relayFrom", protoAddrToNetAddr(m.RelayFromAddr),
"relayTo", protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex", m.InitiatorRelayIndex,
"responderRelayIndex", m.ResponderRelayIndex,
"vpnAddrs", h.vpnAddrs,
)
target := m.RelayToAddr
targetAddr := protoAddrToNetAddr(target)
relay, err := rm.EstablishRelay(h, m)
if err != nil {
rm.l.WithError(err).Error("Failed to update relay for relayTo")
rm.l.Error("Failed to update relay for relayTo", "error", err)
return
}
// Do I need to complete the relays now?
@@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
// I'm the middle man. Let the initiator know that I've established the relay they requested.
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr)
return
}
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0])
return
}
switch peerRelay.State {
@@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
if v == cert.Version1 {
peer := peerHostInfo.vpnAddrs[0]
if !peer.Is4() {
rm.l.WithField("relayFrom", peer).
WithField("relayTo", target).
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
WithField("responderRelayIndex", resp.ResponderRelayIndex).
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address")
rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address",
"relayFrom", peer,
"relayTo", target,
"initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs", peerHostInfo.vpnAddrs,
)
return
}
@@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
msg, err := resp.Marshal()
if err != nil {
rm.l.WithError(err).
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": resp.RelayFromAddr,
"relayTo": resp.RelayToAddr,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnAddrs": peerHostInfo.vpnAddrs}).
Info("send CreateRelayResponse")
rm.l.Info("send CreateRelayResponse",
"relayFrom", resp.RelayFromAddr,
"relayTo", resp.RelayToAddr,
"initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs", peerHostInfo.vpnAddrs,
)
}
}
}
@@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
from := protoAddrToNetAddr(m.RelayFromAddr)
target := protoAddrToNetAddr(m.RelayToAddr)
logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from,
"relayTo": target,
"initiatorRelayIndex": m.InitiatorRelayIndex,
"vpnAddrs": h.vpnAddrs})
logMsg := rm.l.With(
"relayFrom", from,
"relayTo", target,
"initiatorRelayIndex", m.InitiatorRelayIndex,
"vpnAddrs", h.vpnAddrs,
)
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
if f.myVpnAddrsTable.Contains(from) {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
logMsg.Error("Discarding relay request from myself", "myIP", from)
return
}
@@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
"existingRemoteIndex", existingRelay.RemoteIndex)
return
}
case Disestablished:
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
"existingRemoteIndex", existingRelay.RemoteIndex)
return
}
// Mark the relay as 'Established' because it's safe to use again
h.relayState.UpdateRelayForByIpState(from, Established)
case PeerRequested:
// I should never be in this state, because I am terminal, not forwarding.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex,
"state": existingRelay.State}).Error("Unexpected Relay State found")
logMsg.Error("Unexpected Relay State found",
"existingRemoteIndex", existingRelay.RemoteIndex,
"state", existingRelay.State)
}
} else {
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
if err != nil {
logMsg.WithError(err).Error("Failed to add relay")
logMsg.Error("Failed to add relay", "error", err)
return
}
}
relay, ok := h.relayState.QueryRelayForByIp(from)
if !ok {
logMsg.WithField("from", from).Error("Relay State not found")
logMsg.Error("Relay State not found", "from", from)
return
}
@@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
msg, err := resp.Marshal()
if err != nil {
logMsg.
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": from,
"relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}).
Info("send CreateRelayResponse")
rm.l.Info("send CreateRelayResponse",
"relayFrom", from,
"relayTo", target,
"initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs", h.vpnAddrs,
)
}
return
} else {
@@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() {
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
WithField("relayTo", target).
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
WithField("responderRelayIndex", req.ResponderRelayIndex).
WithField("vpnAddr", target).
Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address")
rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address",
"relayFrom", h.vpnAddrs[0],
"relayTo", target,
"initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex", req.ResponderRelayIndex,
"vpnAddr", target,
)
return
}
@@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
msg, err := req.Marshal()
if err != nil {
logMsg.
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": h.vpnAddrs[0],
"relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnAddr": target}).
Info("send CreateRelayRequest")
rm.l.Info("send CreateRelayRequest",
"relayFrom", h.vpnAddrs[0],
"relayTo", target,
"initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex", req.ResponderRelayIndex,
"vpnAddr", target,
)
}
// Also track the half-created Relay state just received
@@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if !ok {
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
if err != nil {
logMsg.
WithError(err).Error("relayManager Failed to allocate a local index for relay")
logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err)
return
}
}

View File

@@ -2,6 +2,7 @@ package nebula
import (
"context"
"log/slog"
"net"
"net/netip"
"slices"
@@ -10,8 +11,6 @@ import (
"sync"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
@@ -66,11 +65,11 @@ type hostnamesResults struct {
network string
lookupTimeout time.Duration
cancelFn func()
l *logrus.Logger
l *slog.Logger
ips atomic.Pointer[map[netip.AddrPort]struct{}]
}
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
r := &hostnamesResults{
hostnames: make([]hostnamePort, len(hostPorts)),
network: network,
@@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
timeoutCancel()
if err != nil {
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
l.Error("DNS resolution failed for static_map host",
"hostname", hostPort.name,
"network", r.network,
"error", err,
)
continue
}
for _, a := range addrs {
@@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
}
}
if different {
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
l.Info("DNS results changed for host list",
"origSet", origSet,
"newSet", netipAddrs,
)
r.ips.Store(&netipAddrs)
onUpdate()
}
@@ -404,12 +410,7 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
for _, v := range r.badRemotes {
if v == remote {
return true
}
}
return false
return slices.Contains(r.badRemotes, remote)
}
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the

View File

@@ -44,7 +44,10 @@ type Service struct {
}
func New(control *nebula.Control) (*Service, error) {
control.Start()
wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
}
})
// Add the nebula wait function to the group so a fatal reader error
// propagates out through errgroup.Wait().
eg.Go(func() error {
return wait()
})
return &s, nil
}

View File

@@ -10,11 +10,11 @@ import (
"time"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup"
@@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
panic(err)
}
logger := logrus.New()
logger.Out = os.Stdout
logger := logging.NewLogger(os.Stdout)
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {

161
ssh.go
View File

@@ -6,19 +6,21 @@ import (
"errors"
"flag"
"fmt"
"log/slog"
"maps"
"net"
"net/netip"
"os"
"reflect"
"path/filepath"
"runtime"
"runtime/pprof"
"sort"
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/sshd"
)
@@ -55,12 +57,12 @@ type sshDeviceInfoFlags struct {
Pretty bool
}
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) {
c.RegisterReloadCallback(func(c *config.C) {
if c.GetBool("sshd.enabled", false) {
sshRun, err := configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd")
l.Error("Failed to reconfigure the sshd", "error", err)
ssh.Stop()
}
if sshRun != nil {
@@ -76,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
// updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
listen := c.GetString("sshd.listen", "")
if listen == "" {
return nil, fmt.Errorf("sshd.listen must be provided")
@@ -118,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, caAuthorizedKey := range rawCAs {
err := ssh.AddTrustedCA(caAuthorizedKey)
if err != nil {
l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey)
continue
}
}
@@ -129,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, rk := range keys {
kDef, ok := rk.(map[string]any)
if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk)
continue
}
user, ok := kDef["user"].(string)
if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk)
continue
}
@@ -144,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
case string:
err := ssh.AddAuthorizedKey(user, v)
if err != nil {
l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
l.Warn("Failed to authorize key",
"error", err,
"sshKeyConfig", rk,
"sshKey", v,
)
continue
}
@@ -152,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, subK := range v {
sk, ok := subK.(string)
if !ok {
l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
l.Warn("Did not understand ssh key",
"sshKeyConfig", rk,
"sshKey", subK,
)
continue
}
err := ssh.AddAuthorizedKey(user, sk)
if err != nil {
l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key")
l.Warn("Failed to authorize key",
"error", err,
"sshKeyConfig", sk,
)
continue
}
}
default:
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk)
}
}
} else {
@@ -176,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
ssh.Stop()
runner = func() {
if err := ssh.Run(listen); err != nil {
l.WithField("err", err).Warn("Failed to run the SSH server")
l.Warn("Failed to run the SSH server", "error", err)
}
}
} else {
@@ -186,7 +198,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
return runner, nil
}
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
// sandboxDir defaults to a dir in temp. The intention is that end user will
// create this dir as needed. Overriding this config value to "" allows
// writing to anywhere in the system.
defaultDir := filepath.Join(os.TempDir(), "nebula-debug")
sandboxDir := c.GetString("sshd.sandbox_dir", defaultDir)
ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts",
@@ -245,7 +263,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "start-cpu-profile",
ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`",
Callback: sshStartCpuProfile,
Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshStartCpuProfile(sandboxDir, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
@@ -260,7 +280,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "save-heap-profile",
ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`",
Callback: sshGetHeapProfile,
Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshGetHeapProfile(sandboxDir, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
@@ -272,7 +294,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "save-mutex-profile",
ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`",
Callback: sshGetMutexProfile,
Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshGetMutexProfile(sandboxDir, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
@@ -505,13 +529,43 @@ func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) er
return nil
}
func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
// sshSanitizeFilePath validates that the given file path is within the sandbox directory.
// If sandboxDir is empty, the path is returned as-is for backwards compatibility.
func sshSanitizeFilePath(sandboxDir, filePath string) (string, error) {
if sandboxDir == "" {
return filePath, nil
}
// Clean and resolve the path relative to the sandbox directory
if !filepath.IsAbs(filePath) {
filePath = filepath.Join(sandboxDir, filePath)
}
cleaned := filepath.Clean(filePath)
// Ensure the resolved path is within the sandbox directory
cleanedSandbox := filepath.Clean(sandboxDir)
if cleaned == cleanedSandbox {
return "", fmt.Errorf("path %q resolves to the sandbox directory itself %q", filePath, sandboxDir)
}
if !strings.HasPrefix(cleaned, cleanedSandbox+string(filepath.Separator)) {
return "", fmt.Errorf("path %q is outside the sandbox directory %q", filePath, sandboxDir)
}
return cleaned, nil
}
func sshStartCpuProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
err := w.WriteLine("No path to write profile provided")
return err
}
file, err := os.Create(a[0])
filePath, err := sshSanitizeFilePath(sandboxDir, a[0])
if err != nil {
return w.WriteLine(err.Error())
}
file, err := os.Create(filePath)
if err != nil {
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
return err
@@ -675,12 +729,17 @@ func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
return w.WriteLine("Changed")
}
func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
func sshGetHeapProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No path to write profile provided")
}
file, err := os.Create(a[0])
filePath, err := sshSanitizeFilePath(sandboxDir, a[0])
if err != nil {
return w.WriteLine(err.Error())
}
file, err := os.Create(filePath)
if err != nil {
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
return err
@@ -711,12 +770,17 @@ func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
}
func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No path to write profile provided")
}
file, err := os.Create(a[0])
filePath, err := sshSanitizeFilePath(sandboxDir, a[0])
if err != nil {
return w.WriteLine(err.Error())
}
file, err := os.Create(filePath)
if err != nil {
return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
}
@@ -735,36 +799,45 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
}
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
ctrl, ok := l.Handler().(interface {
GetLevel() slog.Level
SetLevel(slog.Level)
})
if !ok {
return w.WriteLine("Log level is not reconfigurable on this logger")
}
level, err := logrus.ParseLevel(a[0])
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
}
level, err := logging.ParseLevel(strings.ToLower(a[0]))
if err != nil {
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels))
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a))
}
l.SetLevel(level)
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
ctrl.SetLevel(level)
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
}
func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
ctrl, ok := l.Handler().(interface {
GetFormat() string
SetFormat(string) error
})
if !ok {
return w.WriteLine("Log format is not reconfigurable on this logger")
}
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
}
logFormat := strings.ToLower(a[0])
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{}
case "json":
l.Formatter = &logrus.JSONFormatter{}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil {
return err
}
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
}
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
@@ -831,9 +904,7 @@ func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
relays := map[uint32]*HostInfo{}
ifce.hostMap.Lock()
for k, v := range ifce.hostMap.Relays {
relays[k] = v
}
maps.Copy(relays, ifce.hostMap.Relays)
ifce.hostMap.Unlock()
type RelayFor struct {

View File

@@ -2,19 +2,19 @@ package sshd
import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"net"
"sync"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
type SSHServer struct {
config *ssh.ServerConfig
l *logrus.Entry
l *slog.Logger
certChecker *ssh.CertChecker
@@ -27,20 +27,21 @@ type SSHServer struct {
commands *radix.Tree
listener net.Listener
// Locks the conns/counter to avoid concurrent map access
connsLock sync.Mutex
conns map[int]*session
counter int
// Call the cancel() function to stop all active sessions
ctx context.Context
cancel func()
}
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
ctx, cancel := context.WithCancel(context.Background())
s := &SSHServer{
trustedKeys: make(map[string]map[string]bool),
l: l,
commands: radix.New(),
conns: make(map[int]*session),
ctx: ctx,
cancel: cancel,
}
cc := ssh.CertChecker{
@@ -120,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error {
}
s.trustedCAs = append(s.trustedCAs, pk)
s.l.WithField("sshKey", pubKey).Info("Trusted CA key")
s.l.Info("Trusted CA key", "sshKey", pubKey)
return nil
}
@@ -138,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
}
tk[string(pk.Marshal())] = true
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
s.l.Info("Authorized ssh key",
"sshKey", pubKey,
"sshUser", user,
)
return nil
}
@@ -155,7 +159,7 @@ func (s *SSHServer) Run(addr string) error {
return err
}
s.l.WithField("sshListener", addr).Info("SSH server is listening")
s.l.Info("SSH server is listening", "sshListener", addr)
// Run loops until there is an error
s.run()
@@ -171,11 +175,20 @@ func (s *SSHServer) run() {
c, err := s.listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
s.l.WithError(err).Warn("Error in listener, shutting down")
s.l.Warn("Error in listener, shutting down", "error", err)
}
return
}
go func(c net.Conn) {
// NewServerConn may block while waiting for the client to complete the handshake.
// Ensure that a bad client doesn't hurt us by checking for the parent context
// cancellation before calling NewServerConn, and forcing the socket to close when
// the context is cancelled.
sessionContext, sessionCancel := context.WithCancel(s.ctx)
go func() {
<-sessionContext.Done()
c.Close()
}()
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
fp := ""
if conn != nil {
@@ -183,36 +196,33 @@ func (s *SSHServer) run() {
}
if err != nil {
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
l := s.l.With(
"error", err,
"remoteAddress", c.RemoteAddr(),
)
if conn != nil {
l = l.WithField("sshUser", conn.User())
l = l.With("sshUser", conn.User())
conn.Close()
}
if fp != "" {
l = l.WithField("sshFingerprint", fp)
l = l.With("sshFingerprint", fp)
}
l.Warn("failed to handshake")
continue
sessionCancel()
return
}
l := s.l.WithField("sshUser", conn.User())
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
l := s.l.With("sshUser", conn.User())
l.Info("ssh user logged in",
"remoteAddress", c.RemoteAddr(),
"sshFingerprint", fp,
)
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
s.connsLock.Lock()
s.counter++
counter := s.counter
s.conns[counter] = session
s.connsLock.Unlock()
NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session"))
go ssh.DiscardRequests(reqs)
go func() {
<-session.exitChan
s.l.WithField("id", counter).Debug("closing conn")
s.connsLock.Lock()
delete(s.conns, counter)
s.connsLock.Unlock()
}()
}(c)
}
}
@@ -220,15 +230,11 @@ func (s *SSHServer) Stop() {
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil {
if err := s.listener.Close(); err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
s.l.Warn("Failed to close the sshd listener", "error", err)
}
}
}
func (s *SSHServer) closeSessions() {
s.connsLock.Lock()
for _, c := range s.conns {
c.Close()
}
s.connsLock.Unlock()
s.cancel()
}

Some files were not shown because too many files have changed in this diff Show More