Merge remote-tracking branch 'origin/master' into multiport

This commit is contained in:
Wade Simmons
2026-05-06 14:26:49 -04:00
138 changed files with 10562 additions and 4541 deletions

View File

@@ -209,10 +209,11 @@ jobs:
id: create_release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REF_NAME: ${{ github.ref_name }}
run: |
cd artifacts
gh release create \
--verify-tag \
--title "Release ${{ github.ref_name }}" \
"${{ github.ref_name }}" \
--title "Release ${GITHUB_REF_NAME}" \
"${GITHUB_REF_NAME}" \
SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz

View File

@@ -18,6 +18,8 @@ jobs:
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
name: Run extra smoke tests
runs-on: ubuntu-latest
env:
VAGRANT_DEFAULT_PROVIDER: libvirt
steps:
- uses: actions/checkout@v6
@@ -30,8 +32,13 @@ jobs:
- name: add hashicorp source
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
- name: install vagrant
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
- name: install vagrant and libvirt
run: |
sudo apt-get update && sudo apt-get install -y vagrant libvirt-daemon-system libvirt-dev
sudo chmod 666 /dev/kvm
sudo usermod -aG libvirt $(whoami)
sudo chmod 666 /var/run/libvirt/libvirt-sock
vagrant plugin install vagrant-libvirt
- name: freebsd-amd64
run: make smoke-vagrant/freebsd-amd64
@@ -42,10 +49,19 @@ jobs:
- name: netbsd-amd64
run: make smoke-vagrant/netbsd-amd64
- name: linux-386
run: make smoke-vagrant/linux-386
- name: linux-amd64-ipv6disable
run: make smoke-vagrant/linux-amd64-ipv6disable
# linux-386 runs last because it requires disabling KVM to use VirtualBox,
# which prevents libvirt (used by the other tests) from working after this point.
- name: install virtualbox for i386 test
run: |
sudo apt-get install -y virtualbox
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
- name: linux-386
env:
VAGRANT_DEFAULT_PROVIDER: virtualbox
run: make smoke-vagrant/linux-386
timeout-minutes: 30

View File

@@ -16,8 +16,10 @@ relay:
am_relay: true
EOF
export LIGHTHOUSES="192.168.100.1 172.17.0.2:4242"
export REMOTE_ALLOW_LIST='{"172.17.0.4/32": false, "172.17.0.5/32": false}'
# TEST-NET-3 placeholder IPs; smoke-relay.sh seds them to real container IPs.
# Mapping: .2 lighthouse1, .3 host2, .4 host3, .5 host4.
export LIGHTHOUSES="192.168.100.1 203.0.113.2:4242"
export REMOTE_ALLOW_LIST='{"203.0.113.4/32": false, "203.0.113.5/32": false}'
HOST="host2" ../genconfig.sh >host2.yml <<EOF
relay:
@@ -25,7 +27,7 @@ relay:
- 192.168.100.1
EOF
export REMOTE_ALLOW_LIST='{"172.17.0.3/32": false}'
export REMOTE_ALLOW_LIST='{"203.0.113.3/32": false}'
HOST="host3" ../genconfig.sh >host3.yml

View File

@@ -5,9 +5,15 @@ set -e -x
rm -rf ./build
mkdir ./build
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
# Smoke containers run on a dedicated docker network whose subnet is allocated
# at smoke time, not known at build time. Configs are written with TEST-NET-3
# placeholder IPs (RFC 5737) and smoke.sh / smoke-vagrant.sh / smoke-relay.sh
# sed the real container IPs in before starting nebula.
#
# Placeholder mapping (last octet == fixed container slot):
# 203.0.113.2 -> lighthouse1, 203.0.113.3 -> host2,
# 203.0.113.4 -> host3, 203.0.113.5 -> host4.
LIGHTHOUSE_IP="203.0.113.2"
(
cd build
@@ -25,16 +31,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{
../genconfig.sh >lighthouse1.yml
HOST="host2" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
../genconfig.sh >host2.yml
HOST="host3" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host3.yml
HOST="host4" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml

View File

@@ -6,6 +6,8 @@ set -o pipefail
mkdir -p logs
NETWORK="nebula-smoke-relay"
cleanup() {
echo
echo " *** cleanup"
@@ -16,22 +18,53 @@ cleanup() {
then
docker kill lighthouse1 host2 host3 host4
fi
docker network rm "$NETWORK" >/dev/null 2>&1
}
trap cleanup EXIT
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
# Create a dedicated smoke network with an explicit subnet (required for --ip
# below). Probe a short list of candidates so a locally-used range doesn't
# fail the whole test — we only need one to be free.
docker network rm "$NETWORK" >/dev/null 2>&1 || true
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
break
fi
done
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
exit 1
fi
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build-relay.sh.
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
PREFIX="${SUBNET%/*}"
PREFIX="${PREFIX%.*}"
LIGHTHOUSE_IP="$PREFIX.2"
HOST2_IP="$PREFIX.3"
HOST3_IP="$PREFIX.4"
HOST4_IP="$PREFIX.5"
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
for f in build/host2.yml build/host3.yml build/host4.yml; do
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
mv "$f.tmp" "$f"
done
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" nebula:smoke-relay -config host2.yml -test
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" nebula:smoke-relay -config host3.yml -test
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" nebula:smoke-relay -config host4.yml -test
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
sleep 1
set +x
@@ -76,7 +109,13 @@ docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 5
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
# fixed sleep.
for _ in $(seq 1 30); do
[ -z "$(jobs -r)" ] && break
sleep 1
done
if [ "$(jobs -r)" ]
then

View File

@@ -8,6 +8,8 @@ export VAGRANT_CWD="$PWD/vagrant-$1"
mkdir -p logs
NETWORK="nebula-smoke"
cleanup() {
echo
echo " *** cleanup"
@@ -19,21 +21,51 @@ cleanup() {
docker kill lighthouse1 host2
fi
vagrant destroy -f
docker network rm "$NETWORK" >/dev/null 2>&1
}
trap cleanup EXIT
# Create a dedicated smoke network with an explicit subnet (required for --ip
# below). Probe a short list of candidates so a locally-used range doesn't
# fail the whole test — we only need one to be free.
docker network rm "$NETWORK" >/dev/null 2>&1 || true
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
break
fi
done
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
exit 1
fi
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
# .3 host2 — matches the placeholders in build.sh.
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
PREFIX="${SUBNET%/*}"
PREFIX="${PREFIX%.*}"
LIGHTHOUSE_IP="$PREFIX.2"
HOST2_IP="$PREFIX.3"
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
# This must happen before `vagrant up` rsyncs build/ into the VM for host3.
for f in build/host2.yml build/host3.yml; do
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
mv "$f.tmp" "$f"
done
CONTAINER="nebula:${NAME:-smoke}"
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
vagrant up
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 15
@@ -96,7 +128,14 @@ vagrant ssh -c "ping -c1 192.168.100.2" -- -T
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 1
# Wait up to 30s for all backgrounded jobs to exit. vagrant ssh in particular
# takes a beat to tear down after nebula exits on the VM, so a fixed sleep is
# racy.
for _ in $(seq 1 30); do
[ -z "$(jobs -r)" ] && break
sleep 1
done
if [ "$(jobs -r)" ]
then

View File

@@ -6,6 +6,8 @@ set -o pipefail
mkdir -p logs
NETWORK="nebula-smoke"
cleanup() {
echo
echo " *** cleanup"
@@ -16,38 +18,71 @@ cleanup() {
then
docker kill lighthouse1 host2 host3 host4
fi
docker network rm "$NETWORK" >/dev/null 2>&1
}
trap cleanup EXIT
# Create a dedicated smoke network with an explicit subnet (required for --ip
# below). Probe a short list of candidates so a locally-used range doesn't
# fail the whole test — we only need one to be free.
docker network rm "$NETWORK" >/dev/null 2>&1 || true
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
break
fi
done
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
exit 1
fi
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build.sh.
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
PREFIX="${SUBNET%/*}"
PREFIX="${PREFIX%.*}"
LIGHTHOUSE_IP="$PREFIX.2"
HOST2_IP="$PREFIX.3"
HOST3_IP="$PREFIX.4"
HOST4_IP="$PREFIX.5"
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
# build/lighthouse1.yml has no IPs to rewrite so it's skipped.
for f in build/host2.yml build/host3.yml build/host4.yml; do
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
mv "$f.tmp" "$f"
done
CONTAINER="nebula:${NAME:-smoke}"
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" "$CONTAINER" -config host3.yml -test
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" "$CONTAINER" -config host4.yml -test
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
sleep 1
# grab tcpdump pcaps for debugging
docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
docker exec host2 ncat -nklv 0.0.0.0 2000 &
docker exec host3 ncat -nklv 0.0.0.0 2000 &
docker exec host4 ncat -e '/usr/bin/echo helloagainfromhost4' -nkluv 0.0.0.0 4000 &
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
@@ -119,17 +154,24 @@ echo
echo " *** Testing conntrack"
echo
set -x
# host2 can ping host3 now that host3 pinged it first
docker exec host2 ping -c1 192.168.100.3
# host4 can ping host2 once conntrack established
docker exec host2 ping -c1 192.168.100.4
docker exec host4 ping -c1 192.168.100.2
# host4's outbound firewall only allows ICMP to the lighthouse, so host4
# cannot initiate UDP to host2. Once host2 initiates a flow to host4:4000,
# conntrack must let host4's listener reply on that flow. If it doesn't,
# the echo back from host4 never reaches host2.
docker exec host2 sh -c "(/usr/bin/echo host2; sleep 2) | ncat -nuv 192.168.100.4 4000" | grep -q helloagainfromhost4
docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 5
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
# fixed sleep.
for _ in $(seq 1 30); do
[ -z "$(jobs -r)" ] && break
sleep 1
done
if [ "$(jobs -r)" ]
then

View File

@@ -1,7 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure("2") do |config|
config.vm.box = "ubuntu/jammy64"
config.vm.box = "bento/ubuntu-24.04"
config.vm.synced_folder "../build", "/nebula"

View File

@@ -1,7 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure("2") do |config|
config.vm.box = "generic/openbsd7"
config.vm.box = "DefinedNet/openbsd78"
config.vm.synced_folder "../build", "/nebula", type: "rsync"
end

View File

@@ -2,7 +2,21 @@ version: "2"
linters:
default: none
enable:
- sloglint
- testifylint
settings:
sloglint:
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
# custom levels instead of LogAttrs when you can.
#
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
# the few legitimate sites (where attrs is built up as a []slog.Attr)
# carry a //nolint:sloglint with rationale.
kv-only: true
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
# discard-handler is on by default (since Go 1.24): suggests
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
exclusions:
generated: lax
presets:

View File

@@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.10.3] - 2026-02-06
### Security
- Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations.
Both fingerprint representations will be tested against the blocklist.
Any newly issued P256 based certificates will have their signature clamped to the low-s form.
Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962)
### Changed
- Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588)
## [1.10.2] - 2026-01-21
### Fixed
- Fix panic when using `use_system_route_table` that was introduced in v1.10.1. (#1580)
### Changed
- Fix some typos in comments. (#1582)
- Dependency updates. (#1581)
## [1.10.1] - 2026-01-16
See the [v1.10.1](https://github.com/slackhq/nebula/milestone/26?closed=1) milestone for a complete list of changes.
@@ -764,7 +788,9 @@ created.)
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.1...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD
[1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3
[1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2
[1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7

1
CODEOWNERS Normal file
View File

@@ -0,0 +1 @@
#ECCN:Open Source

View File

@@ -57,7 +57,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
docker pull nebulaoss/nebula
```
#### Mobile
#### Mobile ([source code](https://github.com/DefinedNet/mobile_nebula))
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&amp;itscg=30200)
- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
@@ -76,6 +76,8 @@ Nebula was created to provide a mechanism for groups of hosts to communicate sec
## Getting started (quickly)
**Don't want to manage your own PKI and lighthouses?** [Managed Nebula](https://www.defined.net/) from Defined Networking handles all of this for you.
To set up a Nebula network, you'll need:
#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use.

239
bits.go
View File

@@ -1,23 +1,43 @@
package nebula
import (
"context"
"fmt"
"log/slog"
"math"
mathbits "math/bits"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
const bitsPerWord = 64
// Bits is a sliding-window anti-replay tracker. The window is stored as a
// circular bitmap packed into uint64 words (8x denser than a []bool), so a
// length-N window costs N/8 bytes. length must be a power of two.
type Bits struct {
length uint64
lengthMask uint64
current uint64
bits []bool
bits []uint64
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(bits uint64) *Bits {
func NewBits(length uint64) *Bits {
if length == 0 || length&(length-1) != 0 {
panic(fmt.Sprintf("Bits length must be a power of two, got %d", length))
}
nWords := length / bitsPerWord
if nWords == 0 {
nWords = 1
}
b := &Bits{
length: bits,
bits: make([]bool, bits, bits),
length: length,
lengthMask: length - 1,
bits: make([]uint64, nWords),
current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
@@ -25,88 +45,219 @@ func NewBits(bits uint64) *Bits {
}
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
b.bits[0] = 1
return b
}
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
func (b *Bits) get(i uint64) bool {
pos := i & b.lengthMask
//bit-shifting by 6 because i is a bit index, not a u64 index, and we need to find the u64 word that contains the bit
return b.bits[pos>>6]&(uint64(1)<<(pos&63)) != 0
}
func (b *Bits) set(i uint64) {
pos := i & b.lengthMask
b.bits[pos>>6] |= uint64(1) << (pos & 63)
}
// clearRange clears `count` bits starting at circular position `startPos`
// (already masked to [0, length)) and returns how many of them were set
// before the clear. count must be in [1, length].
func (b *Bits) clearRange(startPos, count uint64) uint64 {
wasSet := uint64(0)
if count >= b.length {
for _, w := range b.bits {
wasSet += uint64(mathbits.OnesCount64(w))
}
clear(b.bits)
return wasSet
}
pos := startPos
remaining := count
// handle the potential partial word before pos becomes u64-aligned
word := pos >> 6
bit := pos & 63
take := uint64(64) - bit
if take > remaining {
take = remaining
}
if take > b.length-pos {
take = b.length - pos
}
var mask uint64
if take == 64 {
mask = math.MaxUint64
} else {
mask = ((uint64(1) << take) - 1) << bit
}
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
b.bits[word] &^= mask
remaining -= take
pos = (pos + take) & b.lengthMask
// Clear whole words, keeping track of the number of set bits
for remaining >= 64 {
word = pos >> 6
wasSet += uint64(mathbits.OnesCount64(b.bits[word]))
b.bits[word] = 0
remaining -= 64
pos = (pos + 64) & b.lengthMask
}
// Clear the remaining partial word
if remaining > 0 {
word = pos >> 6
mask = (uint64(1) << remaining) - 1
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
b.bits[word] &^= mask
}
return wasSet
}
func (b *Bits) strictlyWithinWindow(i uint64) bool {
// Handle the case where the window hasn't slid yet. This avoids u64 underflow.
inWarmup := b.current < b.length
if i < b.length && inWarmup {
return true
}
// Next, if the packet is in-window, see if we've seen it before
if i > b.current-b.length {
return true
}
return false //not within window!
}
// Check returns true if i is within (or way out in front of) the window, and not a replay
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current {
return true
}
// If i is within the window, check if it's been set already.
if i > b.current-b.length || i < b.length && b.current < b.length {
return !b.bits[i%b.length]
if b.strictlyWithinWindow(i) {
return !b.get(i)
}
// Not within the window
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("rejected a packet (top)", "current", b.current, "incoming", i)
}
return false
}
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
// Update has three branches:
// - i == b.current+1: fast path; advance the cursor by one and lose-count
// the slot we just stomped (only past warmup; see the i > b.length guard
// below).
// - i > b.current+1: jump path; clear all slots between current and i
// (or up to a full window's worth, whichever is smaller) via clearRange,
// then mark i. Two arms here: a warmup arm that handles the very first
// window before the cursor has slid, and a steady-state arm that treats
// every cleared empty slot as a lost packet.
// - i <= b.current: in-window check for duplicates; out-of-window otherwise.
//
// NewBits seeds bits[0]=1 so counter 0 looks "received" — Update never
// clears that marker during warmup (clearRange skips position 0 when
// startPos=1), and once b.current >= b.length the marker is no longer
// consulted. The marker prevents a fictitious "lost" hit on the first real
// counter.
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
// Fast path: i is the next expected counter. Split out so the function
// stays small and avoids paying for the slow paths' slog argument-build
// stack frame on every call. The bit read/test/write is inlined to
// touch the backing word once.
if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
// The very first window can only be tracked as lost once we are on the 2nd window or greater
if b.bits[i%b.length] == false && i > b.length {
pos := i & b.lengthMask
word := pos >> 6
mask := uint64(1) << (pos & 63)
w := b.bits[word]
if i > b.length && w&mask == 0 {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
b.bits[word] = w | mask
b.current = i
return true
}
return b.updateSlow(l, i)
}
// updateSlow handles jumps, in-window backfill, dupes, and out-of-window.
func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool {
// If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current {
lost := int64(0)
// Zero out the bits between the current and the new counter value, limited by the window size,
// since the window is shifting
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
end := i
if end > b.current+b.length {
end = b.current + b.length
}
count := end - b.current
startPos := (b.current + 1) & b.lengthMask
var lost int64
if b.current >= b.length {
// Steady state: every cleared slot is past warmup, so any unset
// bit we evict is a lost packet from the previous cycle.
wasSet := b.clearRange(startPos, count)
lost = int64(count) - int64(wasSet)
} else {
// Warmup (the very first window). Some cleared slots represent
// packets <= length where eviction is not "lost" in the usual
// sense. This branch is taken at most once per connection so we
// don't bother optimizing it.
for n := b.current + 1; n <= end; n++ {
if !b.get(n) && n > b.length {
lost++
}
}
b.bits[n%b.length] = false
b.clearRange(startPos, count)
}
// Only record any skipped packets as a result of the window moving further than the window length
// Any loss within the new window will be accounted for in future calls
lost += max(0, int64(i-b.current-b.length))
// Anything past the new window can never be backfilled, so it's lost.
if i > b.current+b.length {
lost += int64(i - b.current - b.length)
}
b.lostCounter.Inc(lost)
b.bits[i%b.length] = true
b.set(i)
b.current = i
return true
}
// If i is within the current window but below the current counter,
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
// If i is within the current window but below the current counter, check to see if it's a duplicate
if b.strictlyWithinWindow(i) {
pos := i & b.lengthMask
word := pos >> 6
mask := uint64(1) << (pos & 63)
w := b.bits[word]
if b.current == i || w&mask != 0 {
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "duplicate",
)
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true
b.bits[word] = w | mask
return true
}
// In all other cases, fail and don't change current.
b.outOfWindowCounter.Inc(1)
if l.Level >= logrus.DebugLevel {
l.WithField("accepted", false).
WithField("currentCounter", b.current).
WithField("incomingCounter", i).
WithField("reason", "nonsense").
Debug("Receive window")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "nonsense",
)
}
return false
}

View File

@@ -7,61 +7,79 @@ import (
"github.com/stretchr/testify/assert"
)
// snapshot returns the bitmap as a []bool of length b.length, for readable
// test assertions against the now-packed []uint64 storage.
func (b *Bits) snapshot() []bool {
out := make([]bool, b.length)
for i := uint64(0); i < b.length; i++ {
out[i] = b.get(i)
}
return out
}
// TestBitsRequiresPowerOfTwo checks that NewBits rejects any window size
// that is not a power of two and accepts every size that is.
func TestBitsRequiresPowerOfTwo(t *testing.T) {
	// Construction must reject sizes that are not a power of two.
	for _, ctor := range []func(){
		func() { NewBits(10) },
		func() { NewBits(0) },
	} {
		assert.Panics(t, ctor)
	}
	// Powers of two, from the minimum up to a large window, are accepted.
	for _, ctor := range []func(){
		func() { NewBits(1) },
		func() { NewBits(16) },
		func() { NewBits(1024) },
		func() { NewBits(16384) },
	} {
		assert.NotPanics(t, ctor)
	}
}
func TestBits(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
// make sure it is the right size
assert.Len(t, b.bits, 10)
b := NewBits(16)
assert.EqualValues(t, 16, b.length)
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1))
assert.True(t, b.Update(l, 1))
assert.EqualValues(t, 1, b.current)
g := []bool{true, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
g := []bool{true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Receive two
assert.True(t, b.Check(l, 2))
assert.True(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
g = []bool{true, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
g = []bool{true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Receive two again - it will fail
assert.False(t, b.Check(l, 2))
assert.False(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15))
assert.True(t, b.Update(l, 15))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Jump ahead to 25, which clears the window and sets slot 25%16 = 9.
assert.True(t, b.Check(l, 25))
assert.True(t, b.Update(l, 25))
assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14))
assert.True(t, b.Update(l, 14))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 24, which is in window (current 25, length 16, window covers [10,25]).
assert.True(t, b.Check(l, 24))
assert.True(t, b.Update(l, 24))
assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// Mark 5, which is not allowed because it is not in the window
// Mark 5, not allowed because 5 <= current-length (25-16=9).
assert.False(t, b.Check(l, 5))
assert.False(t, b.Update(l, 5))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
assert.EqualValues(t, 25, b.current)
g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
assert.Equal(t, g, b.snapshot())
// make sure we handle wrapping around once to the current position
b = NewBits(10)
// Make sure we handle wrapping around once to the same slot. With
// length=16, packets 1 and 17 share slot 1.
b = NewBits(16)
assert.True(t, b.Update(l, 1))
assert.True(t, b.Update(l, 11))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
assert.True(t, b.Update(l, 17))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}, b.snapshot())
// Walk through a few windows in order
b = NewBits(10)
b = NewBits(16)
for i := uint64(1); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
@@ -72,24 +90,31 @@ func TestBits(t *testing.T) {
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
// length=16. Update(55) from current=0:
// warmup, per-bit loop sees no n>16 with unset bits (slot 0 was set by
// NewBits and gets re-evaluated when n=16; n=16 is not strictly > 16),
// so the loop contributes 0. The jump exceeds the window so we record
// 55 - 0 - 16 = 39 packets fell out the back.
b := NewBits(16)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55))
assert.Equal(t, int64(39), b.lostCounter.Count())
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
// Update(100): clears 16 slots starting at slot 56%16=8. Only slot 7 (for
// packet 55) was set, so 16 - 1 = 15 evicted slots had unset bits.
// Plus 100 - 55 - 16 = 29 packets fell past the window. Total 44.
assert.True(t, b.Update(l, 100))
assert.Equal(t, int64(39+44), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
// Update(200): same shape: 16 - 1 = 15 evicted unset, plus 200 - 100 - 16 = 84 past window. Total 99.
assert.True(t, b.Update(l, 200))
assert.Equal(t, int64(39+44+99), b.lostCounter.Count())
}
func TestBitsDupeCounter(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b := NewBits(16)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
@@ -114,120 +139,117 @@ func TestBitsDupeCounter(t *testing.T) {
func TestBitsOutOfWindowCounter(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b := NewBits(16)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
// Jump to 20 (warmup branch + 4 past-window packets).
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
// 9 single-step advances, each evicts a slot whose bit was cleared during
// the jump above and whose value was never seen, so each contributes 1
// to lostCounter.
for n := uint64(21); n <= 29; n++ {
assert.True(t, b.Update(l, n))
}
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
// 0 is below current-length (29-16=13) so it falls outside the window.
assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
// 4 from the Update(20) jump + 9 from 21..29.
assert.Equal(t, int64(13), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
func TestBitsLostCounter(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b := NewBits(16)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
// Walk 20..29 like the original, just with a bigger window. Same
// reasoning as TestBitsOutOfWindowCounter: 4 past-window from Update(20),
// then 9 more from the unit advances.
for n := uint64(20); n <= 29; n++ {
assert.True(t, b.Update(l, n))
}
assert.Equal(t, int64(13), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
b = NewBits(10)
b = NewBits(16)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(1), b.lostCounter.Count())
// Now let's fill in the window, should end up with 8 lost packets
assert.True(t, b.Update(l, 12))
assert.True(t, b.Update(l, 13))
assert.True(t, b.Update(l, 14))
// Update(15) clears the warmup window (no lost), sets slot 15.
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Update(16): slot 0 was already set (NewBits seeded it), and 16 is not
// strictly > length, so nothing is recorded as lost.
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Update(17): we jumped straight from 0 to 15, so slot 1 was cleared
// (and never re-set). 17 > 16 is past warmup, so packet 1 is recorded lost.
assert.True(t, b.Update(l, 17))
assert.True(t, b.Update(l, 18))
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(8), b.lostCounter.Count())
assert.Equal(t, int64(1), b.lostCounter.Count())
// Jump ahead by a window size
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in
assert.True(t, b.Update(l, 30))
assert.True(t, b.Update(l, 31))
assert.True(t, b.Update(l, 32))
assert.True(t, b.Update(l, 33))
assert.True(t, b.Update(l, 34))
assert.True(t, b.Update(l, 35))
assert.True(t, b.Update(l, 36))
assert.True(t, b.Update(l, 37))
assert.True(t, b.Update(l, 38))
// 39 packets tracked, 22 seen, 17 lost
assert.Equal(t, int64(17), b.lostCounter.Count())
// Fill in 18..30 in single steps. Each i evicts slot i%16. Slots 2..14
// were all cleared during Update(15), and we never re-set any of them,
// so each i in 18..30 is a fresh lost packet — 13 more.
for n := uint64(18); n <= 30; n++ {
assert.True(t, b.Update(l, n))
}
assert.Equal(t, int64(14), b.lostCounter.Count())
// Jump ahead by 2 windows, should have recording 1 full window missing
assert.True(t, b.Update(l, 58))
assert.Equal(t, int64(27), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
assert.True(t, b.Update(l, 59))
assert.True(t, b.Update(l, 60))
assert.True(t, b.Update(l, 61))
assert.True(t, b.Update(l, 62))
assert.True(t, b.Update(l, 63))
assert.True(t, b.Update(l, 64))
assert.True(t, b.Update(l, 65))
assert.True(t, b.Update(l, 66))
assert.True(t, b.Update(l, 67))
// 68 packets tracked, 32 seen, 36 missed
assert.Equal(t, int64(36), b.lostCounter.Count())
// Jump ahead by exactly one window size.
assert.True(t, b.Update(l, 46))
// end = min(46, 30+16) = 46, count = 16, all slots cleared. Before the
// jump every slot 0..15 had been set (Update(15), (16), (17), 18..30),
// so wasSet=16 and 46 == current+length means no past-window slack:
// lost contribution = 0.
assert.Equal(t, int64(14), b.lostCounter.Count())
// Walk 47..55. The Update(46) jump cleared every slot, so only slot 14
// (for packet 46) is set when we start. Each subsequent unit step lands
// on a slot that was cleared and is past warmup, so it counts as lost.
// 9 more = 23.
for n := uint64(47); n <= 55; n++ {
assert.True(t, b.Update(l, n))
}
assert.Equal(t, int64(23), b.lostCounter.Count())
// Jump ahead by two windows: clears the window plus past-window loss.
assert.True(t, b.Update(l, 87))
// current=55, length=16. end = min(87, 71) = 71. count=16, all slots
// cleared. Slots set before the clear are slots 14,15,0..7 (10 total).
// Lost from clear = 16 - 10 = 6. Past window: 87 - 55 - 16 = 16. +22.
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b := NewBits(16)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
// Receive 4, backfill 1, then 9, 2, 3, 5, 6, 7 (skip 8), 10, 11, 14.
// Then jump to 25 — slot 25%16=9 is being evicted, but it had been set
// (we received packet 9), so no spurious lost increment. The original
// regression was about double-counting a missing packet when its slot
// got cleared on a jump. With the jump path now using clearRange's
// word-level wasSet count, the same semantics hold.
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
@@ -244,7 +266,7 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
// Skip packet 8.
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
@@ -252,9 +274,23 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
// Jump to 25. With length=16, slot 25%16=9 corresponds to packet 9
// (which we DID receive), so its bit is set and no lost++ from that
// eviction. The trace below shows the only loss is packet 8.
assert.True(t, b.Update(l, 25))
// current was 14, i=25. end=min(25,30)=25. count=11. startPos=15.
// steady? current=14<16, so warmup branch: per-bit n=15..25, count those
// with !get(n) AND n>16. n=17..25 are >16. Among slots 17%16=1..25%16=9
// did we set slots 1..9 (packets 1..9)? Yes for all but slot 8 (packet 8
// was skipped). n=24 maps to slot 8 which is FALSE → lost++. All other
// n in 17..25 map to slots that are set. n=16 is not strictly > 16. So
// lost = 1.
assert.Equal(t, int64(1), b.lostCounter.Count())
// Fill in 12, 13, 15, 16. Each is below current=25 (in-window). 16 must
// recheck slot 0 — it was set by NewBits and then cleared by the
// Update(25) jump, so 16 backfills cleanly.
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
@@ -263,29 +299,140 @@ func TestBitsLostCounterIssue1(t *testing.T) {
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
// We missed packet 8 above and that loss is still recorded once, never
// double-counted, never zeroed.
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {
for i := range z.bits {
z.bits[i] = true
}
for i := range z.bits {
z.bits[i] = false
}
// TestBitsWarmupOvershoot exercises the jump path's warmup arm with an
// overshoot past one full window. A fresh Bits has current=0 and only the
// slot-0 marker set. Jumping straight to length+k must clear every slot the
// jump straddles, charge only the past-window slack to the lost counter
// (in-window warmup slots never held a real packet), and leave current at
// the new counter so later unit advances run in steady state. The marker
// bit at slot 0 stops mattering once current >= length.
func TestBitsWarmupOvershoot(t *testing.T) {
	logger := test.NewLogger()
	w := NewBits(16)
	w.lostCounter.Clear()

	// current=0 → i=20 with length=16: the warmup loop finds no slot that is
	// both unset and strictly past the window length (slot 0 carries the
	// marker), so the loop contributes nothing. The only loss is the
	// overshoot past the window: 20 - 0 - 16 = 4.
	assert.True(t, w.Update(logger, 20))
	assert.Equal(t, int64(4), w.lostCounter.Count())
	assert.Equal(t, uint64(20), w.current)

	// Steady state now (current=20 >= length=16). Advancing to 21 evicts
	// slot 21%16=5, which the jump cleared and nothing re-set: +1 lost.
	assert.True(t, w.Update(logger, 21))
	assert.Equal(t, int64(5), w.lostCounter.Count())
}
// TestBitsCheckAcrossWarmupBoundary pins the uint64-underflow trick in
// Check's in-window clause. While warming up, current-length wraps to a
// huge value so the first OR-clause can never fire; the second clause
// (i < length && current < length) carries the in-window test instead.
// Once current reaches length the two regimes swap over cleanly.
func TestBitsCheckAcrossWarmupBoundary(t *testing.T) {
	logger := test.NewLogger()
	rw := NewBits(16)

	// Warmup, current=0: counter 0 reads the seeded marker bit → rejected.
	assert.False(t, rw.Check(logger, 0), "marker slot should look already-received")

	// Warmup: every 0 < i < length is inside the window and unset → accepted.
	for i := uint64(1); i < 16; i++ {
		assert.True(t, rw.Check(logger, i), "warmup in-window i=%d should be accepted", i)
	}

	// Warmup: anything above current counts as "the next number" → accepted.
	assert.True(t, rw.Check(logger, 16))
	assert.True(t, rw.Check(logger, 1_000_000))

	// Move into steady state.
	assert.True(t, rw.Update(logger, 100))

	// current=100, length=16 → the live window is [85..100].
	// 84 sits just outside: 84 > 100-16=84 is false, and the warmup clause
	// no longer applies (current >= length) → out of window.
	assert.False(t, rw.Check(logger, 84))
	// 85 is the boundary slot: 85 > 84 holds, unset → accepted.
	assert.True(t, rw.Check(logger, 85))
	// 100 is current itself: in window but already marked → duplicate.
	assert.False(t, rw.Check(logger, 100))
	// Far below the window → clearly rejected.
	assert.False(t, rw.Check(logger, 50))
}
// TestBitsMarkerInvariant verifies the seeded bits[0]=1 marker across
// warmup and past it. Warmup updates must never clear the marker
// (clearRange skips position 0 when startPos=1), counter 0 must keep
// reading as a duplicate while warming up, and the first steady-state
// eviction of slot 0 (counter == length) must not charge a lost packet —
// which is exactly what the marker exists to prevent.
func TestBitsMarkerInvariant(t *testing.T) {
	logger := test.NewLogger()
	w := NewBits(8)

	// Counter 0 is the seeded marker; Check treats it as already received.
	assert.False(t, w.Check(logger, 0))

	// Update(0) while current=0 lands in the duplicate branch.
	w.dupeCounter.Clear()
	assert.False(t, w.Update(logger, 0))
	assert.Equal(t, int64(1), w.dupeCounter.Count())

	// Advance through warmup one counter at a time; the marker must survive.
	for n := uint64(1); n <= 7; n++ {
		assert.True(t, w.Update(logger, n))
	}

	// Nothing cleared position 0, so counter 0 still reads as a duplicate.
	assert.False(t, w.Check(logger, 0))

	// Counter 8 (== length) crosses into steady state and evicts the marker
	// slot. The lost guard (i > b.length) is false since 8 == 8, so this
	// advance must NOT be charged as a lost packet.
	w.lostCounter.Clear()
	assert.True(t, w.Update(logger, 8))
	assert.Equal(t, int64(0), w.lostCounter.Count())

	// Slot 0 now belongs to counter 8, which was just received.
	assert.False(t, w.Check(logger, 8))
}
// BenchmarkBitsUpdateInOrder measures the steady-state hot path where
// every call advances the counter by exactly one (i == current+1).
func BenchmarkBitsUpdateInOrder(b *testing.B) {
	logger := test.NewLogger()
	w := NewBits(16384)
	for i := 0; i < b.N; i++ {
		w.Update(logger, uint64(i)+1)
	}
}
// BenchmarkBitsUpdateReorder simulates light reorder inside the window:
// every second packet arrives one slot behind its predecessor, forcing the
// in-window backfill branch on each pair.
func BenchmarkBitsUpdateReorder(b *testing.B) {
	logger := test.NewLogger()
	w := NewBits(16384)
	for i := 0; i < b.N; i++ {
		pair := uint64(i) * 2
		w.Update(logger, pair+2)
		w.Update(logger, pair+1)
	}
}
// BenchmarkBitsUpdateLargeJumps stresses the word-level clearRange path by
// jumping far past the window on every update.
func BenchmarkBitsUpdateLargeJumps(b *testing.B) {
	logger := test.NewLogger()
	w := NewBits(16384)
	for i := 0; i < b.N; i++ {
		w.Update(logger, uint64(i+1)*1000)
	}
}

View File

@@ -1,5 +1,4 @@
//go:build boringcrypto
// +build boringcrypto
package nebula

View File

@@ -1,11 +1,14 @@
package cert
import (
"bufio"
"bytes"
"encoding/pem"
"errors"
"fmt"
"io"
"net/netip"
"slices"
"strings"
"time"
)
@@ -29,22 +32,46 @@ func NewCAPool() *CAPool {
// If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
}
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
// The reader must contain a PEM-encoded set of nebula certificates.
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
pool := NewCAPool()
var err error
var expired bool
for {
caPEMs, err = pool.AddCAFromPEM(caPEMs)
if errors.Is(err, ErrExpired) {
expired = true
err = nil
scanner := bufio.NewScanner(r)
scanner.Split(SplitPEM)
for scanner.Scan() {
pemBytes := scanner.Bytes()
block, rest := pem.Decode(pemBytes)
if len(bytes.TrimSpace(rest)) > 0 {
return nil, ErrInvalidPEMBlock
}
if block == nil {
return nil, ErrInvalidPEMBlock
}
c, err := unmarshalCertificateBlock(block)
if err != nil {
return nil, err
}
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
err = pool.AddCA(c)
if errors.Is(err, ErrExpired) {
expired = true
continue
} else if err != nil {
return nil, err
}
}
if err := scanner.Err(); err != nil {
return nil, ErrInvalidPEMBlock
}
if expired {
return pool, ErrExpired
@@ -141,10 +168,23 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
return nil, err
}
// Pre nebula v1.10.3 could generate signatures in either high or low s form and validation
// of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form
// but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we
// need to test both forms until such a time comes that we enforce low-s form on signature validation.
fp2, err := CalculateAlternateFingerprint(c)
if err != nil {
return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err)
}
if fp2 != "" && ncp.IsBlocklisted(fp2) {
return nil, ErrBlockListed
}
cc := CachedCertificate{
Certificate: c,
InvertedGroups: make(map[string]struct{}),
Fingerprint: fp,
fingerprint2: fp2,
signerFingerprint: signer.Fingerprint,
}
@@ -158,6 +198,11 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
// is a cheaper operation to perform as a result.
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
// Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s
if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) {
return ErrBlockListed
}
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
return err
}

View File

@@ -1,10 +1,14 @@
package cert
import (
"bytes"
"io"
"net/netip"
"strings"
"testing"
"time"
"github.com/slackhq/nebula/cert/p256"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -111,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
assert.Len(t, ppppp.CAs, 1)
}
// oneByteReader wraps a reader to return at most 1 byte per Read call,
// exercising the streaming accumulation logic in NewCAPoolFromPEMReader.
type oneByteReader struct {
r io.Reader
}
func (o *oneByteReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
return o.r.Read(p[:1])
}
// An empty reader — whether truly empty or all whitespace — must yield an
// empty pool with no error.
func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) {
	for _, src := range []io.Reader{
		bytes.NewReader(nil),
		strings.NewReader(" \n\t\n "),
	} {
		pool, err := NewCAPoolFromPEMReader(src)
		require.NoError(t, err)
		assert.Empty(t, pool.CAs)
	}
}
// Feeding two concatenated CA PEMs one byte at a time must still produce a
// pool containing both CAs, proving the scanner accumulates partial reads.
func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) {
	ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
	ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)

	slow := &oneByteReader{r: bytes.NewReader(append(pem1, pem2...))}
	pool, err := NewCAPoolFromPEMReader(slow)
	require.NoError(t, err)
	assert.Len(t, pool.CAs, 2)

	// Both certificates must be present under their fingerprints.
	fpA, err := ca1.Fingerprint()
	require.NoError(t, err)
	fpB, err := ca2.Fingerprint()
	require.NoError(t, err)
	assert.Contains(t, pool.CAs, fpA)
	assert.Contains(t, pool.CAs, fpB)
}
// A BEGIN marker with no matching END must surface as an invalid PEM block.
func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) {
	truncated := strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata")
	_, err := NewCAPoolFromPEMReader(truncated)
	assert.ErrorIs(t, err, ErrInvalidPEMBlock)
}
// Valid PEM followed by non-PEM trailing bytes must be rejected.
func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) {
	_, _, _, caPEM := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
	withGarbage := append(caPEM, []byte("some trailing garbage")...)
	_, err := NewCAPoolFromPEMReader(bytes.NewReader(withGarbage))
	assert.ErrorIs(t, err, ErrInvalidPEMBlock)
}
func TestCertificateV1_Verify(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
@@ -170,6 +228,15 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
// Create a copy of the cert and swap to the alternate form for the signature
nc := c.Copy()
b, err := p256.Swap(c.Signature())
require.NoError(t, err)
require.NoError(t, nc.(*certificateV1).setSignature(b))
_, err = caPool.VerifyCertificate(time.Now(), nc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
@@ -187,7 +254,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
b, err = caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
@@ -196,7 +263,17 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
})
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
cc, err := caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Reset the blocklist and block the alternate form fingerprint
caPool.ResetCertBlocklist()
caPool.BlocklistFingerprint(cc.fingerprint2)
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.NoError(t, err)
}
@@ -394,6 +471,15 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
// Create a copy of the cert and swap to the alternate form for the signature
nc := c.Copy()
b, err := p256.Swap(c.Signature())
require.NoError(t, err)
require.NoError(t, nc.(*certificateV2).setSignature(b))
_, err = caPool.VerifyCertificate(time.Now(), nc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
@@ -411,7 +497,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
b, err = caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
@@ -420,7 +506,17 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
})
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
cc, err := caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Reset the blocklist and block the alternate form fingerprint
caPool.ResetCertBlocklist()
caPool.BlocklistFingerprint(cc.fingerprint2)
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
err = caPool.VerifyCachedCertificate(time.Now(), cc)
require.NoError(t, err)
}

View File

@@ -4,6 +4,8 @@ import (
"fmt"
"net/netip"
"time"
"github.com/slackhq/nebula/cert/p256"
)
type Version uint8
@@ -110,6 +112,9 @@ type CachedCertificate struct {
InvertedGroups map[string]struct{}
Fingerprint string
signerFingerprint string
// A place to store a 2nd fingerprint if the certificate could have one, such as with P256
fingerprint2 string
}
func (cc *CachedCertificate) String() string {
@@ -152,3 +157,31 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
return c, nil
}
// CalculateAlternateFingerprint computes the second fingerprint form a P256
// certificate can have: the one produced when its signature's s value is
// flipped between high and low form. Non-P256 certificates have no
// alternate form and yield "". CAPool blocklist testing through
// `VerifyCertificate` and `VerifyCachedCertificate` automatically performs
// this step.
func CalculateAlternateFingerprint(c Certificate) (string, error) {
	if c.Curve() != Curve_P256 {
		// Only P256 signatures have a high/low-s alternate representation.
		return "", nil
	}

	// Work on a copy so the caller's certificate stays untouched.
	alt := c.Copy()
	flipped, err := p256.Swap(alt.Signature())
	if err != nil {
		return "", err
	}

	switch cert := alt.(type) {
	case *certificateV1:
		if err = cert.setSignature(flipped); err != nil {
			return "", err
		}
	case *certificateV2:
		if err = cert.setSignature(flipped); err != nil {
			return "", err
		}
	default:
		return "", ErrUnknownVersion
	}

	return alt.Fingerprint()
}

127
cert/p256/p256.go Normal file
View File

@@ -0,0 +1,127 @@
package p256
import (
"crypto/elliptic"
"errors"
"math/big"
"filippo.io/bigmod"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// halfN is floor(N/2) for the P256 group order N; an s value at or below
// this threshold is considered "low-s" (the midpoint counts as low).
var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1)

// nMod is the P256 group order as a bigmod.Modulus, used for modular
// arithmetic when flipping s between its high and low forms.
var nMod *bigmod.Modulus

// init builds nMod once at package load. The order is a fixed curve
// parameter, so a failure here is a programmer error and panics.
func init() {
	n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes())
	if err != nil {
		panic(err)
	}
	nMod = n
}
// IsNormalized reports whether the ASN.1-encoded signature sig already has
// its s component in low-s form. An unparseable signature yields an error.
func IsNormalized(sig []byte) (bool, error) {
	r, s, err := parseSignature(sig)
	if err != nil {
		return false, err
	}
	low := checkLowS(r, s)
	return low, nil
}
// checkLowS reports whether s, read as a big-endian integer, is in low-s
// form, i.e. s <= N/2 — the midpoint itself is included in the low set.
// The r component is unused but kept for signature symmetry with callers.
func checkLowS(_, s []byte) bool {
	return new(big.Int).SetBytes(s).Cmp(halfN) <= 0
}
// swap flips s between its high and low forms by computing N - s modulo the
// P256 group order. r is passed through unchanged so callers can re-encode
// the (r, s) pair directly. The returned s has leading zero bytes trimmed,
// always keeping at least one byte.
func swap(r, s []byte) ([]byte, []byte, error) {
	// The redundant `var err error` the original carried is gone; err is
	// introduced by the := below.
	bigS, err := bigmod.NewNat().SetBytes(s, nMod)
	if err != nil {
		return nil, nil, err
	}
	// nMod.Nat() reduces to 0 mod N, so subtracting s yields N - s.
	sFlipped := nMod.Nat().Sub(bigS, nMod)
	out := sFlipped.Bytes(nMod)
	// Bytes() is fixed-width; strip leading zeroes for a minimal encoding.
	for len(out) > 1 && out[0] == 0 {
		out = out[1:]
	}
	return r, out, nil
}
// Normalize returns sig with its s component forced into low-s form. A
// signature that is already low-s is returned unmodified; otherwise the
// flipped (r, s) pair is re-encoded into a fresh ASN.1 signature.
func Normalize(sig []byte) ([]byte, error) {
	r, s, err := parseSignature(sig)
	if err != nil {
		return nil, err
	}
	if checkLowS(r, s) {
		// Already canonical; nothing to rewrite.
		return sig, nil
	}
	flippedR, flippedS, err := swap(r, s)
	if err != nil {
		return nil, err
	}
	return encodeSignature(flippedR, flippedS)
}
// Swap will change sig between its current form to the opposite high or low form.
// Unlike Normalize it always flips, regardless of which form sig is in.
func Swap(sig []byte) ([]byte, error) {
	r, s, err := parseSignature(sig)
	if err != nil {
		return nil, err
	}
	outR, outS, err := swap(r, s)
	if err != nil {
		return nil, err
	}
	return encodeSignature(outR, outS)
}
// parseSignature taken exactly from crypto/ecdsa/ecdsa.go
//
// It decodes an ASN.1 DER SEQUENCE of two INTEGERs into the raw big-endian
// r and s byte slices, rejecting trailing data either inside the sequence
// or after it. Kept byte-identical to the stdlib copy on purpose.
func parseSignature(sig []byte) (r, s []byte, err error) {
	var inner cryptobyte.String
	input := cryptobyte.String(sig)
	if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
		!input.Empty() ||
		!inner.ReadASN1Integer(&r) ||
		!inner.ReadASN1Integer(&s) ||
		!inner.Empty() {
		return nil, nil, errors.New("invalid ASN.1")
	}
	return r, s, nil
}
// encodeSignature re-encodes r and s as an ASN.1 DER SEQUENCE of two
// INTEGERs — the standard ECDSA signature wire format. Any encoding error
// (e.g. a zero-valued component) surfaces from the builder's Bytes call.
func encodeSignature(r, s []byte) ([]byte, error) {
	var b cryptobyte.Builder
	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
		addASN1IntBytes(b, r)
		addASN1IntBytes(b, s)
	})
	return b.Bytes()
}
// addASN1IntBytes encodes in ASN.1 a positive integer represented as
// a big-endian byte slice with zero or more leading zeroes.
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
	// Drop leading zero bytes to reach the minimal DER encoding.
	for len(bytes) > 0 && bytes[0] == 0 {
		bytes = bytes[1:]
	}
	if len(bytes) == 0 {
		// An all-zero (or empty) integer is not a valid component here.
		b.SetError(errors.New("invalid integer"))
		return
	}
	b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) {
		// DER integers are signed: a set high bit would make the value
		// negative, so prepend a 0x00 byte to keep it positive.
		if bytes[0]&0x80 != 0 {
			c.AddUint8(0)
		}
		c.AddBytes(bytes)
	})
}

28
cert/p256/p256_test.go Normal file
View File

@@ -0,0 +1,28 @@
package p256
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"testing"
"github.com/stretchr/testify/require"
)
// TestFlipping verifies that swap is an involution: applying it twice to a
// real ECDSA P-256 signature's S component restores the original value,
// while a single application changes it.
func TestFlipping(t *testing.T) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	sig, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus"))
	require.NoError(t, err)

	r, s0, err := parseSignature(sig)
	require.NoError(t, err)

	r, s1, err := swap(r, s0)
	require.NoError(t, err)

	_, s2, err := swap(r, s1)
	require.NoError(t, err)

	require.Equal(t, s0, s2)
	require.NotEqual(t, s0, s1)
}

View File

@@ -1,12 +1,66 @@
package cert
import (
"bytes"
"encoding/pem"
"errors"
"fmt"
"golang.org/x/crypto/ed25519"
)
// ErrTruncatedPEMBlock is returned when input contains non-whitespace data
// that does not form a complete PEM block.
var ErrTruncatedPEMBlock = errors.New("truncated PEM block")

// SplitPEM is a split function for bufio.Scanner that returns each PEM block.
// Non-PEM text before a BEGIN marker (comments, blank lines) is skipped;
// non-whitespace trailing data with no BEGIN marker is an error at EOF.
func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) {
	begin := bytes.Index(data, []byte("-----BEGIN "))
	switch {
	case begin == -1 && !atEOF:
		// No BEGIN marker yet; ask the Scanner for more input.
		return 0, nil, nil
	case begin == -1:
		// End of input with no BEGIN marker: pure whitespace is fine,
		// anything else is garbage / a truncated block.
		if len(bytes.TrimSpace(data)) > 0 {
			return 0, nil, ErrTruncatedPEMBlock
		}
		return len(data), nil, nil
	}

	// Search for the matching END marker from the BEGIN onward.
	tail := data[begin:]
	endPos := bytes.Index(tail, []byte("-----END "))
	if endPos == -1 {
		if atEOF {
			// BEGIN with no END at EOF: the block is incomplete.
			return 0, nil, ErrTruncatedPEMBlock
		}
		return 0, nil, nil
	}

	// The block ends after the END line's newline.
	nl := bytes.IndexByte(tail[endPos:], '\n')
	if nl == -1 {
		if !atEOF {
			// The END line may still be arriving; wait for more data.
			return 0, nil, nil
		}
		// END marker without a trailing newline at EOF — accept it as-is.
		return len(data), data[begin:], nil
	}

	blockEnd := begin + endPos + nl + 1
	return blockEnd, data[begin:blockEnd], nil
}
const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
@@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock
}
var c Certificate
var err error
switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
c, err := unmarshalCertificateBlock(p)
if err != nil {
return nil, r, err
}
@@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
// Implementations must validate the resulting certificate contains valid information.
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
	if block.Type == CertificateBanner {
		return unmarshalCertificateV1(block.Bytes, nil)
	}
	if block.Type == CertificateV2Banner {
		return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
	}
	return nil, ErrInvalidPEMCertificateBanner
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())

View File

@@ -1,12 +1,88 @@
package cert
import (
"bufio"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// scanAll runs SplitPEM over input via a bufio.Scanner and returns every
// token it produced plus the scanner's terminal error, if any.
func scanAll(t *testing.T, input string) ([]string, error) {
	t.Helper()
	sc := bufio.NewScanner(strings.NewReader(input))
	sc.Split(SplitPEM)
	var out []string
	for sc.Scan() {
		out = append(out, sc.Text())
	}
	return out, sc.Err()
}
// TestSplitPEM_Single checks that one complete PEM block is returned intact,
// including its trailing newline.
func TestSplitPEM_Single(t *testing.T) {
	input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n"
	blocks, err := scanAll(t, input)
	require.NoError(t, err)
	require.Len(t, blocks, 1)
	require.Equal(t, input, blocks[0])
}

// TestSplitPEM_Multiple checks that back-to-back blocks are split into
// separate tokens, preserving order.
func TestSplitPEM_Multiple(t *testing.T) {
	block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n"
	block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
	blocks, err := scanAll(t, block1+block2)
	require.NoError(t, err)
	require.Len(t, blocks, 2)
	require.Equal(t, block1, blocks[0])
	require.Equal(t, block2, blocks[1])
}

// TestSplitPEM_CommentsAndWhitespaceBetweenBlocks checks that non-PEM text
// between blocks (comments, blank lines) is skipped without error.
func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) {
	input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
	blocks, err := scanAll(t, input)
	require.NoError(t, err)
	require.Len(t, blocks, 2)
}

// TestSplitPEM_Empty checks that empty input yields no blocks and no error.
func TestSplitPEM_Empty(t *testing.T) {
	blocks, err := scanAll(t, "")
	require.NoError(t, err)
	require.Empty(t, blocks)
}

// TestSplitPEM_WhitespaceOnly checks that whitespace-only input is treated
// like empty input rather than as truncated data.
func TestSplitPEM_WhitespaceOnly(t *testing.T) {
	blocks, err := scanAll(t, " \n\t\n ")
	require.NoError(t, err)
	require.Empty(t, blocks)
}
// TestSplitPEM_TrailingGarbage checks that non-whitespace content after the
// final block surfaces ErrTruncatedPEMBlock while the complete block is
// still yielded first.
func TestSplitPEM_TrailingGarbage(t *testing.T) {
	input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage"
	blocks, err := scanAll(t, input)
	require.ErrorIs(t, err, ErrTruncatedPEMBlock)
	require.Len(t, blocks, 1)
}

// TestSplitPEM_TruncatedBlock checks that a BEGIN marker with no matching
// END marker is an error at EOF.
func TestSplitPEM_TruncatedBlock(t *testing.T) {
	input := "-----BEGIN TEST-----\npartial data with no end"
	_, err := scanAll(t, input)
	require.ErrorIs(t, err, ErrTruncatedPEMBlock)
}

// TestSplitPEM_NoEndNewline checks that an END marker missing its trailing
// newline at EOF is still accepted as a complete block.
func TestSplitPEM_NoEndNewline(t *testing.T) {
	input := "-----BEGIN TEST-----\ndata\n-----END TEST-----"
	blocks, err := scanAll(t, input)
	require.NoError(t, err)
	require.Len(t, blocks, 1)
	require.Equal(t, input, blocks[0])
}

// TestSplitPEM_GarbageOnly checks that input with no PEM block at all is
// reported as truncated rather than silently consumed.
func TestSplitPEM_GarbageOnly(t *testing.T) {
	_, err := scanAll(t, "this is not PEM data")
	require.ErrorIs(t, err, ErrTruncatedPEMBlock)
}
func TestUnmarshalCertificateFromPEM(t *testing.T) {
goodCert := []byte(`
# A good cert

View File

@@ -9,6 +9,8 @@ import (
"fmt"
"net/netip"
"time"
"github.com/slackhq/nebula/cert/p256"
)
// TBSCertificate represents a certificate intended to be signed.
@@ -126,6 +128,13 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
return nil, err
}
if curve == Curve_P256 {
sig, err = p256.Normalize(sig)
if err != nil {
return nil, err
}
}
err = c.setSignature(sig)
if err != nil {
return nil, err

View File

@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/cert/p256"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -89,3 +90,48 @@ func TestCertificateV1_SignP256(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, uc)
}
// TestCertificate_SignP256_AlwaysNormalized signs many certificates with a
// P-256 key and checks that every produced signature is low-S normalized,
// alternating between certificate versions 1 and 2 on each iteration.
func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) {
	before := time.Now().Add(time.Second * -60).Round(time.Second)
	after := time.Now().Add(time.Second * 60).Round(time.Second)
	// Placeholder public key embedded in the TBS payload; the actual
	// signing key is generated below.
	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
	tbs := TBSCertificate{
		Version: Version1,
		Name:    "testing",
		Networks: []netip.Prefix{
			mustParsePrefixUnmapped("10.1.1.1/24"),
			mustParsePrefixUnmapped("10.1.1.2/16"),
		},
		UnsafeNetworks: []netip.Prefix{
			mustParsePrefixUnmapped("9.1.1.2/24"),
			mustParsePrefixUnmapped("9.1.1.3/16"),
		},
		Groups:    []string{"test-group1", "test-group2", "test-group3"},
		NotBefore: before,
		NotAfter:  after,
		PublicKey: pubKey,
		IsCA:      true,
		Curve:     Curve_P256,
	}
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	// NOTE(review): elliptic.Marshal is deprecated in modern Go; if this
	// test is revisited, consider encoding the public key another way.
	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
	rawPriv := priv.D.FillBytes(make([]byte, 32))
	// ECDSA signatures are randomized, so iterate many times to exercise
	// both raw low-S and raw high-S signatures from the signer.
	for i := 0; i < 1000; i++ {
		if i&1 == 1 {
			tbs.Version = Version1
		} else {
			tbs.Version = Version2
		}
		c, err := tbs.Sign(nil, Curve_P256, rawPriv)
		require.NoError(t, err)
		assert.NotNil(t, c)
		assert.True(t, c.CheckSignature(pub))
		normie, err := p256.IsNormalized(c.Signature())
		require.NoError(t, err)
		assert.True(t, normie)
	}
}

View File

@@ -163,3 +163,55 @@ func P256Keypair() ([]byte, []byte) {
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}
// DummyCert is a minimal cert.Certificate implementation for testing error paths.
// Each interface accessor returns the matching exported field verbatim, so
// tests can build arbitrary certificate shapes with plain struct literals.
type DummyCert struct {
	Version_        cert.Version
	Curve_          cert.Curve
	Groups_         []string
	IsCA_           bool
	Issuer_         string
	Name_           string
	Networks_       []netip.Prefix
	NotAfter_       time.Time
	NotBefore_      time.Time
	PublicKey_      []byte
	Signature_      []byte
	UnsafeNetworks_ []netip.Prefix
}

// Field accessors: each simply echoes the corresponding struct field.
func (d *DummyCert) Version() cert.Version          { return d.Version_ }
func (d *DummyCert) Curve() cert.Curve              { return d.Curve_ }
func (d *DummyCert) Groups() []string               { return d.Groups_ }
func (d *DummyCert) IsCA() bool                     { return d.IsCA_ }
func (d *DummyCert) Issuer() string                 { return d.Issuer_ }
func (d *DummyCert) Name() string                   { return d.Name_ }
func (d *DummyCert) Networks() []netip.Prefix       { return d.Networks_ }
func (d *DummyCert) NotAfter() time.Time            { return d.NotAfter_ }
func (d *DummyCert) NotBefore() time.Time           { return d.NotBefore_ }
func (d *DummyCert) PublicKey() []byte              { return d.PublicKey_ }
func (d *DummyCert) Signature() []byte              { return d.Signature_ }
func (d *DummyCert) UnsafeNetworks() []netip.Prefix { return d.UnsafeNetworks_ }

// Inert stubs: marshaling succeeds with nil output, CheckSignature always
// fails, and validity/expiry checks always pass.
func (d *DummyCert) Fingerprint() (string, error)                  { return "", nil }
func (d *DummyCert) CheckSignature(key []byte) bool                { return false }
func (d *DummyCert) MarshalForHandshakes() ([]byte, error)         { return nil, nil }
func (d *DummyCert) MarshalPEM() ([]byte, error)                   { return nil, nil }
func (d *DummyCert) MarshalJSON() ([]byte, error)                  { return nil, nil }
func (d *DummyCert) Marshal() ([]byte, error)                      { return nil, nil }
func (d *DummyCert) String() string                                { return "dummy" }
func (d *DummyCert) Copy() cert.Certificate                        { return d }
func (d *DummyCert) VerifyPrivateKey(c cert.Curve, k []byte) error { return nil }
func (d *DummyCert) Expired(time.Time) bool                        { return false }
func (d *DummyCert) MarshalPublicKeyPEM() []byte                   { return nil }
func (d *DummyCert) PublicKeyPEM() []byte                          { return nil }
// NewTestCAPool creates a CAPool from the given CA certificates, panicking on error.
// Intended for test setup only, where a bad CA is a programming mistake.
func NewTestCAPool(cas ...cert.Certificate) *cert.CAPool {
	pool := cert.NewCAPool()
	for _, c := range cas {
		err := pool.AddCA(c)
		if err != nil {
			panic(err)
		}
	}
	return pool
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"os"
"strings"
"time"
"github.com/slackhq/nebula/cert"
@@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return err
}
rawCACert, err := os.ReadFile(*vf.caPath)
caFile, err := os.Open(*vf.caPath)
if err != nil {
return fmt.Errorf("error while reading ca: %w", err)
}
defer caFile.Close()
caPool := cert.NewCAPool()
for {
rawCACert, err = caPool.AddCAFromPEM(rawCACert)
if err != nil {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
break
}
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
if err != nil && !errors.Is(err, cert.ErrExpired) {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
rawCert, err := os.ReadFile(*vf.certPath)

View File

@@ -64,7 +64,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
require.ErrorIs(t, err, cert.ErrInvalidPEMBlock)
// make a ca for later
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)

View File

@@ -3,8 +3,15 @@
package main
import "github.com/sirupsen/logrus"
import (
"log/slog"
"os"
func HookLogger(l *logrus.Logger) {
// Do nothing, let the logs flow to stdout/stderr
"github.com/slackhq/nebula/logging"
)
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
// platforms have no special sink to integrate with.
func newPlatformLogger() *slog.Logger {
return logging.NewLogger(os.Stdout)
}

View File

@@ -1,54 +1,86 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"context"
"log/slog"
"strings"
"sync"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/logging"
)
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
// logrus output will be discarded
func HookLogger(l *logrus.Logger) {
l.AddHook(newLogHook(logger))
l.SetOutput(ioutil.Discard)
// newPlatformLogger returns a *slog.Logger that routes every log record
// through the Windows service logger so records end up in the Windows
// Event Log. All the heavy lifting (level management, format swap,
// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
// this file only contributes:
//
// - an io.Writer that forwards each formatted line to the service
// logger at the current record's Event Log severity, and
// - a thin severityTag that embeds *logging.Handler and overrides
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
// column and severity-based filters keep working the way they did
// before the slog migration.
//
// Format (text vs json) is carried by the embedded *logging.Handler, so
// logging.format: json in config still produces JSON lines in Event
// Viewer, same as the pre-slog logrus setup.
func newPlatformLogger() *slog.Logger {
w := &eventLogWriter{}
return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w})
}
type logHook struct {
sl service.Logger
// eventLogWriter forwards slog-formatted lines to the Windows service
// logger at the severity most recently stashed by severityTag.Handle.
// The mutex serializes the stash + inner.Handle + Write cycle per record
// across all concurrent goroutines; slog's builtin text/json handlers
// each hold their own mutex around Write, but that only protects the
// Write call itself, not our stash-then-handle sequence.
type eventLogWriter struct {
mu sync.Mutex
level slog.Level
}
func newLogHook(sl service.Logger) *logHook {
return &logHook{sl: sl}
}
func (h *logHook) Fire(entry *logrus.Entry) error {
line, err := entry.String()
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
return err
}
switch entry.Level {
case logrus.PanicLevel:
return h.sl.Error(line)
case logrus.FatalLevel:
return h.sl.Error(line)
case logrus.ErrorLevel:
return h.sl.Error(line)
case logrus.WarnLevel:
return h.sl.Warning(line)
case logrus.InfoLevel:
return h.sl.Info(line)
case logrus.DebugLevel:
return h.sl.Info(line)
func (w *eventLogWriter) Write(p []byte) (int, error) {
line := strings.TrimRight(string(p), "\n")
switch {
case w.level >= slog.LevelError:
return len(p), logger.Error(line)
case w.level >= slog.LevelWarn:
return len(p), logger.Warning(line)
default:
return nil
return len(p), logger.Info(line)
}
}
func (h *logHook) Levels() []logrus.Level {
return logrus.AllLevels
// severityTag embeds *logging.Handler to pick up everything it does for
// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
// so each record's slog.Level is stashed on the writer before formatting
// and so derived handlers stay wrapped as severityTag rather than
// downgrading to bare *logging.Handler.
type severityTag struct {
*logging.Handler
w *eventLogWriter
}
// Handle stashes the record's level on the shared writer, then formats the
// record via the embedded handler; the writer reads the stashed level when
// forwarding the line to the service logger. The mutex is held across both
// steps so concurrent records cannot interleave a stash with another
// record's Write.
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
	s.w.mu.Lock()
	defer s.w.mu.Unlock()
	s.w.level = r.Level
	return s.Handler.Handle(ctx, r)
}

// WithAttrs keeps derived handlers wrapped as severityTag so severity
// stashing survives logger.With(...) chains.
// NOTE(review): assumes (*logging.Handler).WithAttrs returns a
// *logging.Handler — confirm, otherwise this type assertion panics.
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return s
	}
	return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w}
}

// WithGroup mirrors WithAttrs for named groups.
// NOTE(review): same *logging.Handler type-assertion assumption as WithAttrs.
func (s *severityTag) WithGroup(name string) slog.Handler {
	if name == "" {
		return s
	}
	return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w}
}

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util"
)
@@ -50,9 +50,14 @@ func main() {
os.Exit(0)
}
l := logging.NewLogger(os.Stdout)
if *serviceFlag != "" {
doService(configPath, configTest, Build, serviceFlag)
os.Exit(1)
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
l.Error("Service command failed", "error", err)
os.Exit(1)
}
return
}
if *configPath == "" {
@@ -61,9 +66,6 @@ func main() {
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c := config.NewC(l)
err := c.Load(*configPath)
if err != nil {
@@ -71,6 +73,16 @@ func main() {
os.Exit(1)
}
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -78,8 +90,20 @@ func main() {
}
if !*configTest {
ctrl.Start()
ctrl.ShutdownBlock()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
if err := wait(); err != nil {
l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2)
}
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -7,9 +7,9 @@ import (
"path/filepath"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
)
var logger service.Logger
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
// Start should not block.
logger.Info("Nebula service starting.")
l := logrus.New()
HookLogger(l)
l := newPlatformLogger()
c := config.NewC(l)
err := c.Load(*p.configPath)
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
return fmt.Errorf("failed to load config: %s", err)
}
if err := logging.ApplyConfig(l, c); err != nil {
return fmt.Errorf("failed to apply logging config: %s", err)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
if err != nil {
return err
@@ -57,11 +65,11 @@ func fileExists(filename string) bool {
return true
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
if *configPath == "" {
ex, err := os.Executable()
if err != nil {
panic(err)
return err
}
*configPath = filepath.Dir(ex) + "/config.yaml"
if !fileExists(*configPath) {
@@ -85,16 +93,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
// Here are what the different loggers are doing:
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
// - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
s, err := service.New(prg, svcConfig)
if err != nil {
log.Fatal(err)
return err
}
errs := make(chan error, 5)
logger, err = s.Logger(errs)
if err != nil {
log.Fatal(err)
return err
}
go func() {
@@ -109,18 +117,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
switch *serviceFlag {
case "run":
err = s.Run()
if err != nil {
if err := s.Run(); err != nil {
// Route any errors to the system logger
logger.Error(err)
}
default:
err := service.Control(s, *serviceFlag)
if err != nil {
if err := service.Control(s, *serviceFlag); err != nil {
log.Printf("Valid actions: %q\n", service.ControlAction)
log.Fatal(err)
return err
}
return
}
return nil
}

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util"
)
@@ -55,8 +55,7 @@ func main() {
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
l := logging.NewLogger(os.Stdout)
c := config.NewC(l)
err := c.Load(*configPath)
@@ -65,6 +64,16 @@ func main() {
os.Exit(1)
}
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -72,9 +81,21 @@ func main() {
}
if !*configTest {
ctrl.Start()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l)
ctrl.ShutdownBlock()
if err := wait(); err != nil {
l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2)
}
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -1,11 +1,10 @@
package main
import (
"log/slog"
"net"
"os"
"time"
"github.com/sirupsen/logrus"
)
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
@@ -13,30 +12,30 @@ import (
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
const SdNotifyReady = "READY=1"
func notifyReady(l *logrus.Logger) {
func notifyReady(l *slog.Logger) {
sockName := os.Getenv("NOTIFY_SOCKET")
if sockName == "" {
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
return
}
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
if err != nil {
l.WithError(err).Error("failed to connect to systemd notification socket")
l.Error("failed to connect to systemd notification socket", "error", err)
return
}
defer conn.Close()
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
if err != nil {
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
return
}
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
l.WithError(err).Error("failed to signal the systemd notification socket")
l.Error("failed to signal the systemd notification socket", "error", err)
return
}
l.Debugln("notified systemd the service is ready")
l.Debug("notified systemd the service is ready")
}

View File

@@ -3,8 +3,8 @@
package main
import "github.com/sirupsen/logrus"
import "log/slog"
func notifyReady(_ *logrus.Logger) {
func notifyReady(_ *slog.Logger) {
// No init service to notify
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"os"
"os/signal"
@@ -16,7 +17,6 @@ import (
"time"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3"
)
@@ -26,11 +26,11 @@ type C struct {
Settings map[string]any
oldSettings map[string]any
callbacks []func(*C)
l *logrus.Logger
l *slog.Logger
reloadLock sync.Mutex
}
func NewC(l *logrus.Logger) *C {
func NewC(l *slog.Logger) *C {
return &C{
Settings: make(map[string]any),
l: l,
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
newVals, err := yaml.Marshal(nv)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
c.l.Error("Error while marshaling new config",
"config_path", k,
"error", err,
)
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
c.l.Error("Error while marshaling old config",
"config_path", k,
"error", err,
)
}
return string(newVals) != string(oldVals)
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
err := c.Load(c.path)
if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
c.l.Error("Error occurred while reloading config",
"config_path", c.path,
"error", err,
)
return
}

View File

@@ -5,13 +5,13 @@ import (
"context"
"encoding/binary"
"fmt"
"log/slog"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
@@ -47,10 +47,10 @@ type connectionManager struct {
metricsTxPunchy metrics.Counter
l *logrus.Logger
l *slog.Logger
}
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{
hostMap: hm,
l: l,
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
cm.l.Info("Inactivity timeout has changed",
"oldDuration", old,
"newDuration", cm.getInactivityTimeout(),
)
}
}
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
cm.l.Info("Drop inactive setting has changed",
"oldBool", old,
"newBool", cm.dropInactive.Load(),
)
}
}
}
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
var err error
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
continue
}
switch r.Type {
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
msg, err := req.Marshal()
if err != nil {
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
} else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
cm.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnAddrs": newhostinfo.vpnAddrs}).
Info("send CreateRelayRequest")
cm.l.Info("send CreateRelayRequest",
"relayFrom", req.RelayFromAddr,
"relayTo", req.RelayToAddr,
"initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex", req.ResponderRelayIndex,
"vpnAddrs", newhostinfo.vpnAddrs,
)
}
}
}
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
hostinfo := cm.hostMap.Indexes[localIndex]
if hostinfo == nil {
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
return doNothing, nil, nil
}
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
"tunnelCheck", m{"state": "alive", "method": "passive"},
)
}
hostinfo.pendingDeletion.Store(false)
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if hostinfo.pendingDeletion.Load() {
// We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
hostinfo.logger(cm.l).Info("Tunnel status",
"tunnelCheck", m{"state": "dead", "method": "active"},
)
return deleteTunnel, hostinfo, nil
}
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
if isInactive {
// Tunnel is inactive, tear it down
hostinfo.logger(cm.l).
WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
"inactiveDuration", inactiveFor,
"primary", mainHostInfo,
)
return closeTunnel, hostinfo, primary
}
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
cm.sendPunch(hostinfo)
}
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
"tunnelCheck", m{"state": "testing", "method": "active"},
)
}
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
decision = sendTestPacket
} else {
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Hostinfo sadness")
}
}
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
return false //cert is still valid! yay!
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
// Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is blocked, tearing down the tunnel")
hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
"error", err,
"fingerprint", remoteCert.Fingerprint,
)
return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
"error", err,
"fingerprint", remoteCert.Fingerprint,
)
return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"version", curCrtVersion,
"reason", "local certificate removed",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"version", curCrtVersion,
"peerVersion", peerCrt.Certificate.Version(),
"reason", "local certificate version lower than peer, attempting to correct",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"reason", "local certificate is not current",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"reason", "current cert version < pki.initiatingVersion",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return

View File

@@ -7,9 +7,9 @@ import (
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
@@ -46,13 +46,13 @@ func Test_NewConnectionManagerTest(t *testing.T) {
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
v1Credential: nil,
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -63,9 +63,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
nb := make([]byte, 12, 12)
@@ -79,7 +79,6 @@ func Test_NewConnectionManagerTest(t *testing.T) {
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -129,13 +128,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
v1Credential: nil,
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -146,9 +145,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
nb := make([]byte, 12, 12)
@@ -162,7 +161,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -214,13 +212,13 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
v1Credential: nil,
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -231,12 +229,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
@@ -248,7 +246,6 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -339,15 +336,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
cs := &CertState{
privateKey: []byte{},
v1Cert: &dummyCert{},
v1HandshakeBytes: []byte{},
privateKey: []byte{},
v1Cert: &dummyCert{},
v1Credential: nil,
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -360,9 +357,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
ifce.connectionManager = nc
@@ -371,7 +368,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ConnectionState: &ConnectionState{
myCert: &dummyCert{},
peerCert: cachedPeerCert,
H: &noise.HandshakeState{},
},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

View File

@@ -1,16 +1,12 @@
package nebula
import (
"crypto/rand"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/handshake"
)
const ReplayWindow = 1024
@@ -18,7 +14,6 @@ const ReplayWindow = 1024
type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
myCert cert.Certificate
peerCert *cert.CachedCertificate
initiator bool
@@ -27,55 +22,24 @@ type ConnectionState struct {
writeLock sync.Mutex
}
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
if cs.pkcs11Backed {
dhFunc = noiseutil.DHP256PKCS11
} else {
dhFunc = noiseutil.DHP256
}
default:
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
}
var ncs noise.CipherSuite
if cs.cipher == "chachapoly" {
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
}
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
//NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKey: []byte{},
PresharedKeyPlacement: 0,
})
if err != nil {
return nil, fmt.Errorf("NewConnectionState: %s", err)
}
// The queue and ready params prevent a counter race that would happen when
// sending stored packets and simultaneously accepting new traffic.
// newConnectionStateFromResult builds a fully-populated ConnectionState from a
// completed handshake.Result. It seeds messageCounter and the replay window so
// that the post-handshake message indices already used on the wire don't count
// as missed traffic in the data plane.
func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
ci := &ConnectionState{
H: hs,
initiator: initiator,
myCert: r.MyCert,
initiator: r.Initiator,
peerCert: r.RemoteCert,
eKey: NewNebulaCipherState(r.EKey),
dKey: NewNebulaCipherState(r.DKey),
window: NewBits(ReplayWindow),
myCert: crt,
}
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2)
return ci, nil
ci.messageCounter.Add(r.MessageIndex)
for i := uint64(1); i <= r.MessageIndex; i++ {
ci.window.Update(nil, i)
}
return ci
}
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {

114
connection_state_test.go Normal file
View File

@@ -0,0 +1,114 @@
package nebula
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// runTestHandshake runs a complete IX handshake between two freshly-built
// peers and returns the initiator and responder Results. Used to produce
// real cipher states for tests that need to exercise post-handshake glue.
func runTestHandshake(t *testing.T) (initR, respR *handshake.Result) {
t.Helper()
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
makeCreds := func(name string, networks []netip.Prefix) handshake.GetCredentialFunc {
c, _, rawKey, _ := ct.NewTestCert(
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
)
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawKey)
require.NoError(t, err)
hsBytes, err := c.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
cred := handshake.NewCredential(c, hsBytes, priv, ncs)
return func(v cert.Version) *handshake.Credential {
if v == cert.Version2 {
return cred
}
return nil
}
}
verifier := func(c cert.Certificate) (*cert.CachedCertificate, error) {
return caPool.VerifyCertificate(time.Now(), c)
}
initCreds := makeCreds("initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCreds := makeCreds("responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initM, err := handshake.NewMachine(
cert.Version2, initCreds, verifier,
func() (uint32, error) { return 1000, nil },
true, header.HandshakeIXPSK0,
)
require.NoError(t, err)
respM, err := handshake.NewMachine(
cert.Version2, respCreds, verifier,
func() (uint32, error) { return 2000, nil },
false, header.HandshakeIXPSK0,
)
require.NoError(t, err)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
resp, respR, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
require.NotNil(t, respR)
_, initR, err = initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, initR)
return initR, respR
}
func TestNewConnectionStateFromResult(t *testing.T) {
initR, respR := runTestHandshake(t)
t.Run("initiator", func(t *testing.T) {
ci := newConnectionStateFromResult(initR)
assert.True(t, ci.initiator)
assert.Equal(t, initR.MyCert, ci.myCert)
assert.Equal(t, initR.RemoteCert, ci.peerCert)
assert.NotNil(t, ci.eKey)
assert.NotNil(t, ci.dKey)
// IX has 2 handshake messages; the next data-plane send is counter=3.
assert.Equal(t, uint64(2), ci.messageCounter.Load(),
"messageCounter must equal Result.MessageIndex so the next send is N+1")
// Both handshake counters must be marked seen so they don't appear lost.
// Check returns false if an index has already been recorded.
assert.False(t, ci.window.Check(nil, 1), "counter 1 must already be seen")
assert.False(t, ci.window.Check(nil, 2), "counter 2 must already be seen")
// Counter 3 is the next data-plane message and must NOT be pre-marked.
assert.True(t, ci.window.Check(nil, 3), "counter 3 must not be pre-seeded")
})
t.Run("responder", func(t *testing.T) {
ci := newConnectionStateFromResult(respR)
assert.False(t, ci.initiator)
assert.Equal(t, respR.MyCert, ci.myCert)
assert.Equal(t, respR.RemoteCert, ci.peerCert)
assert.NotNil(t, ci.eKey)
assert.NotNil(t, ci.dKey)
assert.Equal(t, uint64(2), ci.messageCounter.Load())
})
}

View File

@@ -2,17 +2,33 @@ package nebula
import (
"context"
"errors"
"log/slog"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
StateUnknown RunState = iota
StateReady
StateStarted
StateStopping
StateStopped
)
var ErrAlreadyStarted = errors.New("nebula is already started")
var ErrAlreadyStopped = errors.New("nebula cannot be restarted")
var ErrUnknownState = errors.New("nebula state is invalid")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,8 +42,11 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
l *slog.Logger
ctx context.Context
cancel context.CancelFunc
sshStart func()
@@ -49,10 +68,31 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Start actually runs nebula, this is a nonblocking call.
// The returned function blocks until nebula has fully stopped and returns the
// first fatal reader error (if any). A nil error means nebula shut down
// gracefully; a non-nil error means a reader hit an unexpected failure that
// triggered the shutdown.
func (c *Control) Start() (func() error, error) {
c.stateLock.Lock()
defer c.stateLock.Unlock()
switch c.state {
case StateReady:
//yay!
case StateStopped, StateStopping:
return nil, ErrAlreadyStopped
case StateStarted:
return nil, ErrAlreadyStarted
default:
return nil, ErrUnknownState
}
// Activate the interface
c.f.activate()
err := c.f.activate()
if err != nil {
c.state = StateStopped
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -71,25 +111,51 @@ func (c *Control) Start() {
c.lighthouseStart()
}
c.f.triggerShutdown = c.Stop
// Start reading packets.
c.f.run()
out, err := c.f.run()
if err != nil {
c.state = StateStopped
return nil, err
}
c.state = StateStarted
return out, nil
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != StateStarted {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = StateStopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
c.CloseAllTunnels(false)
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
c.l.Error("Close interface failed", "error", err)
}
c.l.Info("Goodbye")
c.stateLock.Lock()
c.state = StateStopped
c.stateLock.Unlock()
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
@@ -100,7 +166,7 @@ func (c *Control) ShutdownBlock() {
rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.l.Info("Caught signal, shutting down", "signal", sig)
c.Stop()
}
@@ -237,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h)
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
c.l.Debug("Sending close tunnel message",
"vpnAddrs", h.vpnAddrs,
"udpAddr", h.remote,
)
closed++
}

View File

@@ -6,7 +6,6 @@ import (
"reflect"
"testing"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
@@ -79,10 +78,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, &Interface{})
c := Control{
state: StateReady,
f: &Interface{
hostMap: hm,
},
l: logrus.New(),
l: test.NewLogger(),
}
thi := c.GetHostInfoByVpnAddr(vpnIp, false)

View File

@@ -1,13 +1,10 @@
//go:build e2e_testing
// +build e2e_testing
package nebula
import (
"net/netip"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
@@ -23,7 +20,9 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message
panic(err)
}
pipeTo.InjectUDPPacket(p)
if h.Type == msgType && h.Subtype == subType {
match := h.Type == msgType && h.Subtype == subType
p.Release()
if match {
return
}
}
@@ -39,7 +38,9 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
panic(err)
}
pipeTo.InjectUDPPacket(p)
if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType {
match := h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType
p.Release()
if match {
return
}
}
@@ -91,65 +92,15 @@ func (c *Control) GetTunTxChan() <-chan []byte {
return c.f.inside.(*overlay.TestTun).TxPackets
}
// InjectUDPPacket will inject a packet into the udp side of nebula
// InjectUDPPacket injects a packet into the udp side. We copy internally so the caller keeps ownership of p.
// The copy comes from the freelist so steady-state alloc is zero.
func (c *Control) InjectUDPPacket(p *udp.Packet) {
c.f.outside.(*udp.TesterConn).Send(p)
c.f.outside.(*udp.TesterConn).Send(p.Copy())
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
serialize := make([]gopacket.SerializableLayer, 0)
var netLayer gopacket.NetworkLayer
if toAddr.Is6() {
if !fromAddr.Is6() {
panic("Cant send ipv6 to ipv4")
}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} else {
if !fromAddr.Is4() {
panic("Cant send ipv4 to ipv6")
}
ip := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
}
udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
err := udp.SetNetworkLayerForChecksum(netLayer)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
serialize = append(serialize, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, serialize...)
if err != nil {
panic(err)
}
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
// InjectTunPacket pushes an IP packet onto the tun interface.
func (c *Control) InjectTunPacket(packet []byte) {
c.f.inside.(*overlay.TestTun).Send(packet)
}
func (c *Control) GetVpnAddrs() []netip.Addr {

View File

@@ -84,30 +84,24 @@ end
function nebula.prefs_changed()
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
-- Nothing changed, bail
return
end
-- Remove our old dissector
-- Remove all existing registrations
DissectorTable.get("udp.port"):remove_all(nebula)
if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then
default_settings.all_port = nebula.prefs.all_ports
if nebula.prefs.all_ports then
-- Register on every port for hole punch capture
for i=0, 65535 do
DissectorTable.get("udp.port"):add(i, nebula)
end
-- no need to establish again on specific ports
return
else
-- Register on the configured port only
DissectorTable.get("udp.port"):add(nebula.prefs.port, nebula)
end
if default_settings.all_ports ~= nebula.prefs.all_ports then
-- Add our new port dissector
default_settings.port = nebula.prefs.port
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
end
default_settings.all_ports = nebula.prefs.all_ports
default_settings.port = nebula.prefs.port
end
DissectorTable.get("udp.port"):add(default_settings.port, nebula)

View File

@@ -1,63 +1,249 @@
package nebula
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// This whole thing should be rewritten to use context
var dnsR *dnsRecords
var dnsServer *dns.Server
var dnsAddr string
type dnsRecords struct {
type dnsServer struct {
sync.RWMutex
l *logrus.Logger
l *slog.Logger
ctx context.Context
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
mux *dns.ServeMux
// enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`.
// Start, Add, and reload consult it so callers don't need to know the
// gating rules. When it toggles off via reload, accumulated records are
// cleared so a later re-enable starts with a fresh map populated from
// new handshakes.
enabled atomic.Bool
serverMu sync.Mutex
server *dns.Server
// started is closed once `server` has finished binding (or after
// ListenAndServe returns on a bind failure). Stop waits on it before
// calling Shutdown to avoid the miekg/dns "server not started" race
// where a Shutdown that arrives before bind completes is silently
// ignored, leaving the listener running forever.
started chan struct{}
addr string
}
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{
// newDnsServerFromConfig builds a dnsServer, applies the initial config, and
// registers a reload callback. The reload callback is registered before the
// initial config is applied, so a SIGHUP can later enable, fix, or disable
// DNS even if the initial application failed.
//
// The dnsServer internally gates on `lighthouse.serve_dns &&
// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally,
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
// watcher that tears the listener down on nebula shutdown. The returned
// pointer is always non-nil, even on error.
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
ds := &dnsServer{
l: l,
ctx: ctx,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
}
ds.mux = dns.NewServeMux()
ds.mux.HandleFunc(".", ds.handleDnsRequest)
c.RegisterReloadCallback(func(c *config.C) {
if err := ds.reload(c, false); err != nil {
ds.l.Error("Failed to reload DNS responder from config", "error", err)
}
})
if err := ds.reload(c, true); err != nil {
return ds, err
}
return ds, nil
}
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
// reload applies the latest config and reconciles the running state with it:
// - enabled toggled on -> spawn a runner
// - enabled toggled off -> stop the runner
// - listen address changed (while running) -> restart on the new address
// - everything else -> no-op
//
// On the initial call it only records configuration; Control.Start is what
// launches the first runner via dnsStart.
func (d *dnsServer) reload(c *config.C, initial bool) error {
wantsDns := c.GetBool("lighthouse.serve_dns", false)
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
enabled := wantsDns && amLighthouse
newAddr := getDnsServerAddr(c)
d.serverMu.Lock()
running := d.server
runningStarted := d.started
sameAddr := d.addr == newAddr
d.addr = newAddr
d.enabled.Store(enabled)
d.serverMu.Unlock()
if initial {
if wantsDns && !amLighthouse {
d.l.Warn("DNS server refusing to run because this host is not a lighthouse.")
}
return nil
}
if !enabled {
if running != nil {
d.Stop()
}
// Drop any records that accumulated while enabled; a later re-enable
// will repopulate from fresh handshakes.
d.clearRecords()
return nil
}
if running == nil {
// Was disabled (or never started); bring it up now.
go d.Start()
return nil
}
if sameAddr {
return nil
}
d.shutdownServer(running, runningStarted, "reload")
// Old Start goroutine has now exited; bring up a fresh listener on the
// new address.
go d.Start()
return nil
}
// shutdownServer waits for the server to finish binding (so Shutdown actually
// stops it rather than no-oping) and then shuts it down.
func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) {
if srv == nil {
return
}
if started != nil {
<-started
}
if err := srv.Shutdown(); err != nil {
d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
}
}
// Start binds and serves the DNS responder. Blocks until Stop is called or
// the listener errors. Safe to call when DNS is disabled (returns
// immediately). This is what Control.dnsStart points at.
//
// Must be invoked after the tun device is active so that lighthouse.dns.host
// may bind to a nebula IP.
func (d *dnsServer) Start() {
if !d.enabled.Load() {
return
}
started := make(chan struct{})
d.serverMu.Lock()
if d.ctx.Err() != nil {
d.serverMu.Unlock()
return
}
addr := d.addr
server := &dns.Server{
Addr: addr,
Net: "udp",
Handler: d.mux,
NotifyStartedFunc: func() { close(started) },
}
d.server = server
d.started = started
d.serverMu.Unlock()
// Per-invocation ctx watcher. Exits when Start does, so we don't leak a
// watcher per reload-driven restart.
done := make(chan struct{})
go func() {
select {
case <-d.ctx.Done():
d.shutdownServer(server, started, "shutdown")
case <-done:
}
}()
d.l.Info("Starting DNS responder", "dnsListener", addr)
err := server.ListenAndServe()
close(done)
// If the listener never bound (bind error) NotifyStartedFunc never fires,
// so close started here to release any Stop caller waiting on it.
select {
case <-started:
default:
close(started)
}
if err != nil {
d.l.Warn("Failed to run the DNS responder", "error", err)
}
}
// Stop shuts down the active server, if any. Idempotent.
func (d *dnsServer) Stop() {
d.serverMu.Lock()
srv := d.server
started := d.started
d.server = nil
d.started = nil
d.serverMu.Unlock()
d.shutdownServer(srv, started, "stop")
}
// Query returns the address for the given name and query type. The second
// return value reports whether the name is known at all (in either A or AAAA),
// which lets callers distinguish NODATA from NXDOMAIN.
func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
addr4, haveV4 := d.dnsMap4[data]
addr6, haveV6 := d.dnsMap6[data]
nameExists := haveV4 || haveV6
switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
if haveV4 {
return addr4, nameExists
}
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
if haveV6 {
return addr6, nameExists
}
}
return netip.Addr{}
return netip.Addr{}, nameExists
}
func (d *dnsRecords) QueryCert(data string) string {
func (d *dnsServer) QueryCert(data string) string {
if len(data) < 2 {
return ""
}
ip, err := netip.ParseAddr(data[:len(data)-1])
if err != nil {
return ""
@@ -80,8 +266,19 @@ func (d *dnsRecords) QueryCert(data string) string {
return string(b)
}
// clearRecords drops all DNS records.
func (d *dnsServer) clearRecords() {
d.Lock()
defer d.Unlock()
clear(d.dnsMap4)
clear(d.dnsMap6)
}
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
func (d *dnsServer) Add(host string, addresses []netip.Addr) {
if !d.enabled.Load() {
return
}
host = strings.ToLower(host)
d.Lock()
defer d.Unlock()
@@ -101,7 +298,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
}
}
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
@@ -116,13 +313,24 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
return d.myVpnAddrsTable.Contains(b)
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
// type must be answered with NOERROR and an empty answer section (NODATA),
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
// exist at all.
anyNameExists := false
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name)
if debugEnabled {
d.l.Debug("DNS query", "type", qType, "name", q.Name)
}
ip, nameExists := d.Query(q.Qtype, q.Name)
if nameExists {
anyNameExists = true
}
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
@@ -134,7 +342,9 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}
d.l.Debugf("Query for TXT %s", q.Name)
if debugEnabled {
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
}
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
@@ -145,12 +355,12 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
}
}
if len(m.Answer) == 0 {
if len(m.Answer) == 0 && !anyNameExists {
m.Rcode = dns.RcodeNameError
}
}
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
@@ -163,21 +373,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func
dns.HandleFunc(".", dnsR.handleDnsRequest)
c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c)
})
return func() {
startDns(l, c)
}
}
func getDnsServerAddr(c *config.C) string {
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
@@ -186,25 +381,3 @@ func getDnsServerAddr(c *config.C) string {
}
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
}
func startDns(l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
err := dnsServer.ListenAndServe()
defer dnsServer.Shutdown()
if err != nil {
l.Errorf("Failed to start server: %s\n ", err.Error())
}
}
func reloadDns(l *logrus.Logger, c *config.C) {
if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected")
return
}
l.Debug("Restarting DNS server")
dnsServer.Shutdown()
go startDns(l, c)
}

View File

@@ -1,19 +1,43 @@
package nebula
import (
"context"
"log/slog"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type stubDNSWriter struct{}
func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (stubDNSWriter) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353}
}
func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil }
func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil }
func (stubDNSWriter) Close() error { return nil }
func (stubDNSWriter) TsigStatus() error { return nil }
func (stubDNSWriter) TsigTimersOnly(bool) {}
func (stubDNSWriter) Hijack() {}
func TestParsequery(t *testing.T) {
l := logrus.New()
l := slog.New(slog.DiscardHandler)
hostMap := &HostMap{}
ds := newDnsRecords(l, &CertState{}, hostMap)
ds := &dnsServer{
l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
}
ds.enabled.Store(true)
addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
@@ -21,18 +45,56 @@ func TestParsequery(t *testing.T) {
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")})
ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")})
m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
// A known name with no record of the requested type should return NODATA
// (NOERROR with empty answer), not NXDOMAIN.
m = &dns.Msg{}
m.SetQuestion("v4only.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
m = &dns.Msg{}
m.SetQuestion("v6only.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
// An unknown name should still return NXDOMAIN.
m = &dns.Msg{}
m.SetQuestion("unknown.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
// short lookups should not fail
m = &dns.Msg{}
m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
ds.parseQuery(m, stubDNSWriter{})
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
m = &dns.Msg{}
m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
ds.parseQuery(m, stubDNSWriter{})
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
}
func Test_getDnsServerAddr(t *testing.T) {
@@ -71,3 +133,208 @@ func Test_getDnsServerAddr(t *testing.T) {
}
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
}
// newTestDnsServer builds a dnsServer wired to a discard logger, empty DNS
// maps, and a bare HostMap, plus a fresh empty config.C, for use by the
// reload/lifecycle tests below.
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
	t.Helper()
	srv := &dnsServer{
		l:       slog.New(slog.DiscardHandler),
		ctx:     context.Background(),
		dnsMap4: map[string]netip.Addr{},
		dnsMap6: map[string]netip.Addr{},
		hostMap: &HostMap{},
	}
	mux := dns.NewServeMux()
	mux.HandleFunc(".", srv.handleDnsRequest)
	srv.mux = mux
	return srv, config.NewC(nil)
}
// setDnsConfig overwrites the "lighthouse" section of c with the given DNS
// listener host/port and the am_lighthouse/serve_dns flags.
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
	dnsSection := map[string]any{
		"host": host,
		"port": port,
	}
	c.Settings["lighthouse"] = map[string]any{
		"am_lighthouse": amLighthouse,
		"serve_dns":     serveDns,
		"dns":           dnsSection,
	}
}
// TestDnsServer_reload_initial_disabled verifies that an initial reload with
// serve_dns off records the configured address but leaves the responder gate
// closed and constructs no server.
func TestDnsServer_reload_initial_disabled(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, false)
	require.NoError(t, ds.reload(c, true))
	assert.False(t, ds.enabled.Load())
	assert.Equal(t, "127.0.0.1:0", ds.addr)
	assert.Nil(t, ds.server)
}
// TestDnsServer_reload_initial_enabled verifies that an initial reload with
// serve_dns on records the address and opens the enabled gate, but still
// does not start a server on its own.
func TestDnsServer_reload_initial_enabled(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, true)
	require.NoError(t, ds.reload(c, true))
	assert.True(t, ds.enabled.Load())
	assert.Equal(t, "127.0.0.1:0", ds.addr)
	// initial never starts a runner; that's Control.Start's job
	assert.Nil(t, ds.server)
}
// TestDnsServer_reload_initial_serveDnsWithoutLighthouse verifies that
// serve_dns is ignored when the node is not configured as a lighthouse.
func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", false, true)
	require.NoError(t, ds.reload(c, true))
	// Wants DNS but isn't a lighthouse: gated off, no runner.
	assert.False(t, ds.enabled.Load())
}
// TestDnsServer_reload_sameAddr_noOp verifies that a non-initial reload with
// an unchanged listen address neither starts nor restarts anything.
func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, true)
	require.NoError(t, ds.reload(c, true))
	// No server running yet, no addr change. Reload should not spawn anything.
	require.NoError(t, ds.reload(c, false))
	assert.True(t, ds.enabled.Load())
	assert.Nil(t, ds.server)
}
// TestDnsServer_StartStop_lifecycle binds a real (random) UDP port so we
// exercise the actual ListenAndServe + Shutdown plumbing, including the
// started-chan race fix, and verifies that Stop unblocks Start.
func TestDnsServer_StartStop_lifecycle(t *testing.T) {
	port := freeUDPPort(t)
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", port, true, true)
	require.NoError(t, ds.reload(c, true))
	done := make(chan struct{})
	go func() {
		ds.Start()
		close(done)
	}()
	// Wait until the listener has actually bound. This previously duplicated
	// the poll loop inline; reuse the shared waitForBind helper instead so
	// the bind-detection logic lives in exactly one place.
	waitForBind(t, ds)
	ds.Stop()
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("Start did not return after Stop")
	}
}
// TestDnsServer_Stop_beforeBind_doesNotHang verifies that Stop around a
// failed bind neither deadlocks nor hangs.
func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
	// Stop called immediately after Start should not deadlock even if bind
	// hasn't completed yet. This exercises the started-chan close-on-bind-fail
	// path: by pointing the listener at an address that cannot exist we get a
	// fast bind error before NotifyStartedFunc fires.
	ds, c := newTestDnsServer(t)
	// "256.256.256.256" is not a valid IPv4 address, so the UDP listen fails
	// immediately rather than timing out.
	setDnsConfig(c, "256.256.256.256", "53", true, true)
	require.NoError(t, ds.reload(c, true))
	done := make(chan struct{})
	go func() {
		ds.Start()
		close(done)
	}()
	// Give Start a moment to attempt the bind and fail.
	select {
	case <-done:
		// Bind failed and Start returned; Stop should be a no-op.
	case <-time.After(time.Second):
		t.Fatal("Start did not return after a bad bind")
	}
	stopped := make(chan struct{})
	go func() {
		ds.Stop()
		close(stopped)
	}()
	select {
	case <-stopped:
	case <-time.After(time.Second):
		t.Fatal("Stop hung after a failed bind")
	}
}
// TestDnsServer_reload_disable_stopsRunningServer verifies that flipping
// serve_dns off at reload shuts down an already-running listener, letting
// the Start goroutine return and closing the enabled gate.
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
	port := freeUDPPort(t)
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", port, true, true)
	require.NoError(t, ds.reload(c, true))
	startReturned := make(chan struct{})
	go func() {
		ds.Start()
		close(startReturned)
	}()
	waitForBind(t, ds)
	// Toggle serve_dns off; reload should shut the running server down.
	setDnsConfig(c, "127.0.0.1", port, true, false)
	require.NoError(t, ds.reload(c, false))
	select {
	case <-startReturned:
	case <-time.After(5 * time.Second):
		t.Fatal("Start did not return after reload disabled DNS")
	}
	assert.False(t, ds.enabled.Load())
}
// freeUDPPort asks the kernel for an ephemeral UDP port on loopback,
// releases it, and returns the port number as a string. The port is only
// probably free — something else could claim it before the caller binds —
// which is an acceptable race for these tests.
func freeUDPPort(t *testing.T) string {
	t.Helper()
	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := pc.LocalAddr().(*net.UDPAddr)
	require.NoError(t, pc.Close())
	return strconv.Itoa(addr.Port)
}
// waitForBind blocks until ds has published a started channel and that
// channel has been closed (i.e. the UDP listener is bound), failing the
// test via waitFor's timeout otherwise.
func waitForBind(t *testing.T, ds *dnsServer) {
	t.Helper()
	bound := func() bool {
		// Snapshot the channel under the lock; it is swapped by reload.
		ds.serverMu.Lock()
		ch := ds.started
		ds.serverMu.Unlock()
		if ch == nil {
			return false
		}
		select {
		case <-ch:
			return true
		default:
		}
		return false
	}
	waitFor(t, bound)
}
// waitFor polls cond every few milliseconds until it returns true, failing
// the test if it does not become true within the timeout.
func waitFor(t *testing.T, cond func() bool) {
	t.Helper()
	const (
		timeout = 5 * time.Second
		poll    = 5 * time.Millisecond
	)
	for start := time.Now(); time.Since(start) < timeout; {
		if cond() {
			return
		}
		time.Sleep(poll)
	}
	t.Fatal("timed out waiting for condition")
}

View File

@@ -0,0 +1,577 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
// makeHandshakePacket builds a synthetic handshake packet: a valid nebula
// header with the given subtype/index/counter, followed by a deterministic
// filler payload (the payload is never parsed by the drop-path tests).
func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, remoteIndex uint32, counter uint64) *udp.Packet {
	const packetLen = 200
	buf := make([]byte, packetLen)
	header.Encode(buf, header.Version, header.Handshake, subtype, remoteIndex, counter)
	// Fill everything after the header with a predictable byte pattern.
	for i := range buf[header.Len:] {
		buf[header.Len+i] = byte(header.Len + i)
	}
	return &udp.Packet{From: from, To: to, Data: buf}
}
// TestHandshakeRetransmitDuplicate checks that a responder re-sends its
// cached response when the initiator's msg1 is retransmitted, and that no
// duplicate tunnels result.
func TestHandshakeRetransmitDuplicate(t *testing.T) {
	t.Parallel()
	// Verify the responder correctly handles receiving the same msg1 multiple times
	// (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen
	// and the cached response is resent.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Trigger handshake from me to them")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
	t.Log("Grab my msg1")
	msg1 := myControl.GetFromUDP(true)
	t.Log("Inject msg1 into them, first time")
	theirControl.InjectUDPPacket(msg1)
	// Discard the first response; we deliberately complete the handshake with
	// the response produced by the duplicate below.
	_ = theirControl.GetFromUDP(true)
	t.Log("Inject the SAME msg1 again, tests ErrAlreadySeen path")
	theirControl.InjectUDPPacket(msg1)
	resp2 := theirControl.GetFromUDP(true)
	assert.NotNil(t, resp2, "should get cached response on duplicate msg1")
	t.Log("Complete handshake with cached response")
	myControl.InjectUDPPacket(resp2)
	myControl.WaitForType(1, 0, theirControl)
	t.Log("Drain cached packet and verify tunnel works")
	cachedPacket := theirControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Verify only one tunnel exists on each side")
	assert.Len(t, myControl.ListHostmapHosts(false), 1)
	assert.Len(t, theirControl.ListHostmapHosts(false), 1)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeTruncatedPacketRecovery checks that a response truncated to
// the bare header is ignored without killing the pending handshake, and the
// intact response still completes it.
func TestHandshakeTruncatedPacketRecovery(t *testing.T) {
	t.Parallel()
	// Verify that a truncated handshake packet is ignored and the real
	// packet can still complete the handshake.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Trigger handshake")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
	t.Log("Get msg1 and deliver to responder")
	msg1 := myControl.GetFromUDP(true)
	theirControl.InjectUDPPacket(msg1)
	t.Log("Get the real response")
	realResp := theirControl.GetFromUDP(true)
	t.Log("Truncate the response and inject, should be ignored")
	// Keep a parseable header but drop the entire noise payload.
	truncResp := realResp.Copy()
	truncResp.Data = truncResp.Data[:header.Len]
	myControl.InjectUDPPacket(truncResp)
	t.Log("Verify pending handshake survived the truncated packet")
	assert.NotEmpty(t, myControl.ListHostmapHosts(true), "pending handshake should still exist")
	t.Log("Inject real response, should complete handshake")
	myControl.InjectUDPPacket(realResp)
	myControl.WaitForType(1, 0, theirControl)
	t.Log("Drain and verify tunnel")
	cachedPacket := theirControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeOrphanedMsg2Dropped verifies that a msg2 with no matching
// pending index is silently dropped: no response is sent, no hostmap indexes
// are created, and the existing tunnel keeps working.
func TestHandshakeOrphanedMsg2Dropped(t *testing.T) {
	t.Parallel()
	// A msg2 arriving with no matching pending index should be silently dropped
	// with no response sent and no state changes.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Complete a normal handshake")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
	r.RouteForAllUntilTxTun(theirControl)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Record hostmap state")
	myIndexes := len(myControl.ListHostmapIndexes(false))
	t.Log("Inject a fake msg2 with unknown RemoteIndex")
	// counter=2 marks this as a stage-2 (response) packet; 0xDEADBEEF matches
	// no pending handshake on my side.
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0xDEADBEEF, 2))
	t.Log("Verify no new indexes created")
	// assert.Len reads better than comparing len() with assert.Equal and
	// reports the collection contents on failure.
	assert.Len(t, myControl.ListHostmapIndexes(false), myIndexes)
	t.Log("Verify no UDP response was sent")
	time.Sleep(100 * time.Millisecond)
	assert.Nil(t, myControl.GetFromUDP(false), "should not send a response to orphaned msg2")
	t.Log("Verify existing tunnel still works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeUnknownMessageCounter verifies that handshake packets whose
// message counter is neither stage 1 nor stage 2 are dropped with no side
// effects and no UDP response.
func TestHandshakeUnknownMessageCounter(t *testing.T) {
	t.Parallel()
	// A handshake packet with an unexpected message counter should be silently
	// dropped with no side effects and no UDP response.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	myControl.Start()
	theirControl.Start()
	t.Log("Inject handshake with MessageCounter=3")
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 3))
	t.Log("Inject handshake with MessageCounter=99")
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 99))
	t.Log("Verify no tunnels or pending handshakes")
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	t.Log("Verify no UDP response was sent")
	// Brief pause so any (incorrect) response would have been queued.
	time.Sleep(100 * time.Millisecond)
	assert.Nil(t, myControl.GetFromUDP(false))
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeUnknownSubtype verifies that a handshake packet carrying an
// unrecognized subtype is silently dropped with no state or response.
func TestHandshakeUnknownSubtype(t *testing.T) {
	t.Parallel()
	// A handshake packet with an unknown subtype should be silently dropped.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.Start()
	theirControl.Start()
	t.Log("Inject handshake with unknown subtype 99")
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.MessageSubType(99), 0, 1))
	t.Log("Verify no tunnels or pending handshakes")
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	t.Log("Verify no UDP response was sent")
	// Brief pause so any (incorrect) response would have been queued.
	time.Sleep(100 * time.Millisecond)
	assert.Nil(t, myControl.GetFromUDP(false))
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeLateResponse verifies that a response arriving after the
// initiator's handshake has timed out is ignored and creates no tunnel.
func TestHandshakeLateResponse(t *testing.T) {
	t.Parallel()
	// After a handshake times out, a late response should be silently ignored
	// with no new tunnels created.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	// Short retry interval and low retry count so the handshake times out
	// quickly within the test.
	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
		"handshakes": m{
			"try_interval": "200ms",
			"retries":      2,
		},
	})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	myControl.Start()
	theirControl.Start()
	t.Log("Trigger handshake from me")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
	t.Log("Grab msg1 but don't deliver")
	msg1 := myControl.GetFromUDP(true)
	t.Log("Wait for handshake to time out")
	// Drain retransmits while waiting so the UDP queue stays empty.
	for i := 0; i < 5; i++ {
		time.Sleep(300 * time.Millisecond)
		myControl.GetFromUDP(false)
	}
	t.Log("Confirm no pending handshakes remain")
	assert.Empty(t, myControl.ListHostmapHosts(true))
	t.Log("Deliver old msg1 to them, they create a tunnel")
	theirControl.InjectUDPPacket(msg1)
	resp := theirControl.GetFromUDP(true)
	assert.NotNil(t, resp)
	t.Log("Inject late response into me, should be ignored")
	myControl.InjectUDPPacket(resp)
	t.Log("No tunnel should exist on my side")
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeSelfConnectionRejected verifies that a node does not respond
// to, or build a tunnel from, a msg1 that carries its own VPN identity.
func TestHandshakeSelfConnectionRejected(t *testing.T) {
	t.Parallel()
	// Verify that a node rejects a handshake containing its own VPN IP in the
	// peer cert. We do this by sending the initiator's own msg1 back to itself.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	// Need a lighthouse entry to trigger a handshake
	myControl.InjectLightHouseAddr(netip.MustParseAddr("10.128.0.2"), netip.MustParseAddrPort("10.0.0.2:4242"))
	myControl.Start()
	t.Log("Trigger handshake from me")
	myControl.InjectTunPacket(BuildTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
	msg1 := myControl.GetFromUDP(true)
	t.Log("Drain any handshake retransmits before injecting")
	time.Sleep(100 * time.Millisecond)
	// Busy-drain the non-blocking UDP queue until it is empty.
	for myControl.GetFromUDP(false) != nil {
	}
	t.Log("Feed my own msg1 back to me as if it came from someone else")
	// Spoof the source so the packet looks like it came from a third party.
	selfMsg := msg1.Copy()
	selfMsg.From = netip.MustParseAddrPort("10.0.0.99:4242")
	selfMsg.To = myUdpAddr
	myControl.InjectUDPPacket(selfMsg)
	t.Log("Verify no response was sent (self-connection rejected)")
	time.Sleep(100 * time.Millisecond)
	// Drain any further retransmits from the original handshake, then check
	// that none of them are a handshake response (MessageCounter=2)
	h := &header.H{}
	for {
		p := myControl.GetFromUDP(false)
		if p == nil {
			break
		}
		_ = h.Parse(p.Data)
		assert.NotEqual(t, uint64(2), h.MessageCounter,
			"should not send a stage 2 response to self-connection")
	}
	t.Log("Verify no tunnel to myself was created")
	assert.Nil(t, myControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false))
	myControl.Stop()
}
// TestHandshakeMessageCounter0Dropped verifies that a handshake packet with
// MessageCounter=0 is dropped with no state or response.
func TestHandshakeMessageCounter0Dropped(t *testing.T) {
	t.Parallel()
	// MessageCounter=0 is not a valid handshake message and should be dropped.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	// Only the second server's UDP address is needed; it is never started.
	_, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.Start()
	t.Log("Inject handshake with MessageCounter=0")
	myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 0))
	time.Sleep(100 * time.Millisecond)
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	assert.Nil(t, myControl.GetFromUDP(false))
	myControl.Stop()
}
// TestHandshakeRemoteAllowList verifies that remote_allow_list filtering is
// applied to inbound handshakes: a blocked underlay source is dropped with
// no state changes, while the identical packet from an allowed source
// completes normally.
func TestHandshakeRemoteAllowList(t *testing.T) {
	t.Parallel()
	// Verify that a handshake from a blocked underlay IP is dropped with no
	// response and no state changes. Then verify the same packet from an
	// allowed IP succeeds.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	// Allow only 10.0.0.0/8; everything else is denied by the 0.0.0.0/0 rule.
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
		"lighthouse": m{
			"remote_allow_list": m{
				"10.0.0.0/8": true,
				"0.0.0.0/0":  false,
			},
		},
	})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Trigger handshake from them")
	theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")))
	msg1 := theirControl.GetFromUDP(true)
	t.Log("Rewrite the source to a blocked IP and inject")
	blockedMsg := msg1.Copy()
	blockedMsg.From = netip.MustParseAddrPort("192.168.1.1:4242")
	myControl.InjectUDPPacket(blockedMsg)
	t.Log("Verify no tunnel, no pending, no response from blocked source")
	time.Sleep(100 * time.Millisecond)
	assert.Empty(t, myControl.ListHostmapHosts(false))
	assert.Empty(t, myControl.ListHostmapHosts(true))
	assert.Nil(t, myControl.GetFromUDP(false), "should not respond to blocked source")
	t.Log("Now inject the real packet from the allowed source")
	myControl.InjectUDPPacket(msg1)
	t.Log("Verify handshake completes from allowed source")
	resp := myControl.GetFromUDP(true)
	assert.NotNil(t, resp)
	theirControl.InjectUDPPacket(resp)
	theirControl.WaitForType(1, 0, myControl)
	t.Log("Drain cached packet and verify tunnel works")
	cachedPacket := myControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi"), cachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeAlreadySeenPreferredRemote verifies that a duplicate msg1
// handled via the ErrAlreadySeen path leaves the existing tunnel functional,
// keeps the peer's remote unchanged, and creates no extra hostmap indexes.
func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
	t.Parallel()
	// When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel
	// remains functional and hostmap index count is stable.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
	myControl.Start()
	theirControl.Start()
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	t.Log("Complete a normal handshake via the router")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
	r.RouteForAllUntilTxTun(theirControl)
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Record hostmap state")
	theirIndexes := len(theirControl.ListHostmapIndexes(false))
	hi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, hi)
	originalRemote := hi.CurrentRemote
	t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")))
	r.RouteForAllUntilTxTun(theirControl)
	t.Log("Verify tunnel still works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Verify remote is still valid and index count is stable")
	hi2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, hi2)
	assert.Equal(t, originalRemote, hi2.CurrentRemote)
	// assert.Len reads better than comparing len() with assert.Equal and
	// reports the collection contents on failure.
	assert.Len(t, theirControl.ListHostmapIndexes(false), theirIndexes,
		"no extra indexes should be created from ErrAlreadySeen")
	myControl.Stop()
	theirControl.Stop()
}
// TestHandshakeWrongResponderPacketStore verifies recovery when the wrong
// host answers a handshake: the evil tunnel is closed, its address blocked,
// the cached packets carry over, and the correct tunnel is established.
func TestHandshakeWrongResponderPacketStore(t *testing.T) {
	t.Parallel()
	// Verify that when the wrong host responds, the cached packets are
	// transferred to the new handshake, the evil tunnel is closed, evil's
	// address is blocked, and the correct tunnel is eventually established.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
	evilControl, evilVpnIpNet, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
	// Deliberately point the lighthouse entry for "them" at evil's address.
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
	r := router.NewR(t, myControl, theirControl, evilControl)
	defer r.RenderFlow()
	myControl.Start()
	theirControl.Start()
	evilControl.Start()
	t.Log("Send multiple packets to them (cached during handshake)")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")))
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")))
	t.Log("Route until evil tunnel is closed")
	h := &header.H{}
	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
		if err := h.Parse(p.Data); err != nil {
			panic(err)
		}
		// Stop routing once we observe the CloseTunnel sent to evil.
		if h.Type == header.CloseTunnel && p.To == evilUdpAddr {
			return router.RouteAndExit
		}
		return router.KeepRouting
	})
	t.Log("Verify evil's address is blocked in the new pending handshake")
	pendingHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
	if pendingHI != nil {
		assert.NotContains(t, pendingHI.RemoteAddrs, evilUdpAddr,
			"evil's address should be blocked")
	}
	t.Log("Inject correct lighthouse addr for them")
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	t.Log("Route until cached packets arrive at the real them")
	p := r.RouteForAllUntilTxTun(theirControl)
	assert.NotNil(t, p, "a cached packet should be delivered to the correct host")
	t.Log("Verify the correct host has a tunnel")
	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
	t.Log("Verify no hostinfo artifacts from evil remain")
	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), true),
		"no pending hostinfo for evil")
	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), false),
		"no main hostinfo for evil")
	myControl.Stop()
	theirControl.Stop()
	evilControl.Stop()
}
// TestHandshakeRelayComplete verifies that a handshake brokered through a
// relay node completes and that relay state is recorded on all three nodes.
func TestHandshakeRelayComplete(t *testing.T) {
	t.Parallel()
	// Verify that a relay handshake completes correctly and relay state is
	// properly maintained on all three nodes.
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
	// "me" only knows how to reach "them" via the relay.
	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	r := router.NewR(t, myControl, relayControl, theirControl)
	defer r.RenderFlow()
	myControl.Start()
	relayControl.Start()
	theirControl.Start()
	t.Log("Trigger handshake via relay")
	myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")))
	p := r.RouteForAllUntilTxTun(theirControl)
	assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
	t.Log("Verify bidirectional tunnel via relay")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
	t.Log("Verify relay state on my side shows relay-to-me")
	myHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
	assert.NotNil(t, myHI)
	assert.NotEmpty(t, myHI.CurrentRelaysToMe, "should have relay-to-me for them")
	t.Log("Verify relay state on their side shows relay-to-me")
	theirHI := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, theirHI)
	assert.NotEmpty(t, theirHI.CurrentRelaysToMe, "should have relay-to-me for me")
	t.Log("Verify relay node shows through-me relays")
	relayHI := relayControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	assert.NotNil(t, relayHI)
	myControl.Stop()
	relayControl.Stop()
	theirControl.Stop()
}
// NOTE: Relay V1 cert + IPv6 rejection is not tested here because
// BuildTunUDPPacket from a V4 node to a V6 address panics in the test
// framework. The check is in handshake_manager.go handleOutbound relay
// logic (lines ~304-313): if the relay host has a V1 cert and either
// address is IPv6, the relay is skipped.
// NOTE: Relay reestablishment (Disestablished state transition) is covered
// by the existing TestReestablishRelays in handshakes_test.go.

View File

@@ -11,12 +11,12 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -40,11 +40,22 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
// Pre-build the IP packet bytes once so the bench measures the data plane,
// not gopacket SerializeLayers overhead.
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
// EnableFanIn switches the router to a 0-alloc routing path. Required
// for hot-path benchmarks; would conflict with GetFromUDP-using tests.
r.EnableFanIn()
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
myControl.InjectTunPacket(prebuilt)
// Release the TUN-side bytes back to the harness freelist; the bench
// just confirms a packet arrived, the contents aren't inspected.
overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
}
myControl.Stop()
@@ -72,11 +83,15 @@ func BenchmarkHotPathRelay(b *testing.B) {
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
r.EnableFanIn()
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
myControl.InjectTunPacket(prebuilt)
overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
}
myControl.Stop()
@@ -85,6 +100,7 @@ func BenchmarkHotPathRelay(b *testing.B) {
}
func TestGoodHandshake(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
@@ -97,7 +113,7 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -135,6 +151,7 @@ func TestGoodHandshake(t *testing.T) {
}
func TestGoodHandshakeNoOverlap(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
@@ -170,6 +187,7 @@ func TestGoodHandshakeNoOverlap(t *testing.T) {
}
func TestWrongResponderHandshake(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
@@ -189,7 +207,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -246,6 +264,7 @@ func TestWrongResponderHandshake(t *testing.T) {
}
func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
@@ -270,7 +289,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -328,6 +347,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
}
func TestStage1Race(t *testing.T) {
t.Parallel()
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
// But will eventually collapse down to a single tunnel
@@ -348,8 +368,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true)
@@ -408,6 +428,7 @@ func TestStage1Race(t *testing.T) {
}
func TestUncleanShutdownRaceLoser(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
@@ -425,7 +446,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -436,7 +457,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")))
p = r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -457,6 +478,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
}
func TestUncleanShutdownRaceWinner(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
@@ -474,7 +496,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
@@ -486,7 +508,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
@@ -508,6 +530,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
}
func TestRelays(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
@@ -528,7 +551,7 @@ func TestRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -537,6 +560,7 @@ func TestRelays(t *testing.T) {
}
func TestRelaysDontCareAboutIps(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
@@ -557,7 +581,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -566,6 +590,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) {
}
func TestReestablishRelays(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
@@ -586,14 +611,14 @@ func TestReestablishRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Ensure packet traversal from them to me via the relay")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works")
@@ -608,7 +633,7 @@ func TestReestablishRelays(t *testing.T) {
for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
return router.RouteAndExit
@@ -625,7 +650,7 @@ func TestReestablishRelays(t *testing.T) {
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p = r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -660,7 +685,7 @@ func TestReestablishRelays(t *testing.T) {
t.Log("Assert the tunnel works the other way, too")
for {
t.Log("RouteForAllUntilTxTun")
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works")
@@ -697,6 +722,7 @@ func TestReestablishRelays(t *testing.T) {
}
func TestStage1RaceRelays(t *testing.T) {
t.Parallel()
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -729,8 +755,8 @@ func TestStage1RaceRelays(t *testing.T) {
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl)
@@ -744,12 +770,12 @@ func TestStage1RaceRelays(t *testing.T) {
}
func TestStage1RaceRelays2(t *testing.T) {
t.Parallel()
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
@@ -771,49 +797,41 @@ func TestStage1RaceRelays2(t *testing.T) {
theirControl.Start()
r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
r.Log("Wait for a packet from them to me")
l.Info("Wait for a packet from them to me; myControl")
r.Log("Wait for a packet from them to me; myControl")
r.RouteForAllUntilTxTun(myControl)
l.Info("Wait for a packet from them to me; theirControl")
r.Log("Wait for a packet from them to me; theirControl")
r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels")
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
len(myControl.GetHostmap().Indexes),
len(theirControl.GetHostmap().Indexes),
len(relayControl.GetHostmap().Indexes),
)
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
retries := 60
for hostInfos > 6 && retries > 0 {
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
len(myControl.GetHostmap().Indexes),
len(theirControl.GetHostmap().Indexes),
len(relayControl.GetHostmap().Indexes),
)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
@@ -821,7 +839,6 @@ func TestStage1RaceRelays2(t *testing.T) {
}
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
myControl.Stop()
@@ -830,6 +847,7 @@ func TestStage1RaceRelays2(t *testing.T) {
}
func TestRehandshakingRelays(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
@@ -850,7 +868,7 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -933,6 +951,7 @@ func TestRehandshakingRelays(t *testing.T) {
}
func TestRehandshakingRelaysPrimary(t *testing.T) {
t.Parallel()
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
@@ -954,7 +973,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
@@ -1037,6 +1056,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
}
func TestRehandshaking(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
@@ -1132,6 +1152,7 @@ func TestRehandshaking(t *testing.T) {
}
func TestRehandshakingLoser(t *testing.T) {
t.Parallel()
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
// Should be the one with the new certificate
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -1230,6 +1251,7 @@ func TestRehandshakingLoser(t *testing.T) {
}
func TestRaceRegression(t *testing.T) {
t.Parallel()
// This test forces stage 1, stage 2, stage 1 to be received by me from them
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
// caused a cross-linked hostinfo
@@ -1253,8 +1275,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true)
@@ -1290,6 +1312,7 @@ func TestRaceRegression(t *testing.T) {
}
func TestV2NonPrimaryWithLighthouse(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})
@@ -1330,6 +1353,7 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
}
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
@@ -1369,7 +1393,84 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
theirControl.Stop()
}
func TestLighthouseUpdateOnReload(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
// Create the lighthouse
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh", "10.128.0.1/24", m{"lighthouse": m{"am_lighthouse": true}})
// Create a client with NO lighthouse configured and a long update interval.
// The initial SendUpdate at startup will be a no-op since no lighthouses are known.
myControl, myVpnIpNet, _, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.2/24", m{
"lighthouse": m{
"interval": 600,
"local_allow_list": m{
"10.0.0.0/24": true,
"::/0": false,
},
},
})
r := router.NewR(t, lhControl, myControl)
defer r.RenderFlow()
lhControl.Start()
myControl.Start()
// Drain any startup packets (there should be none meaningful)
r.FlushAll()
// Verify lighthouse has no knowledge of the client
assert.Nil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
// Build a new config that adds the lighthouse
newSettings := make(m)
for k, v := range myConfig.Settings {
newSettings[k] = v
}
newSettings["static_host_map"] = m{
lhVpnIpNet[0].Addr().String(): []any{lhUdpAddr.String()},
}
newSettings["lighthouse"] = m{
"hosts": []any{lhVpnIpNet[0].Addr().String()},
"interval": 600,
"local_allow_list": m{
"10.0.0.0/24": true,
"::/0": false,
},
}
newCfg, err := yaml.Marshal(newSettings)
require.NoError(t, err)
// Reload the config. The lighthouse.hosts change triggers TriggerUpdate,
// which wakes the update worker. It calls SendUpdate, initiating a
// handshake to the new lighthouse and caching the HostUpdateNotification.
require.NoError(t, myConfig.ReloadConfigString(string(newCfg)))
// Route until the lighthouse receives the HostUpdateNotification.
// This covers: handshake stage 1, stage 2, then the cached update.
done := make(chan struct{})
go func() {
r.RouteForAllUntilAfterMsgTypeTo(lhControl, header.LightHouse, 0)
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for lighthouse update after config reload")
}
// Verify lighthouse now has the client's addresses
assert.NotNil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
r.RenderHostmaps("Final hostmaps", lhControl, myControl)
lhControl.Stop()
myControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
t.Parallel()
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
@@ -1391,7 +1492,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -1419,7 +1520,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)

View File

@@ -4,7 +4,6 @@
package e2e
import (
"fmt"
"io"
"net/netip"
"os"
@@ -12,15 +11,18 @@ import (
"testing"
"time"
"log/slog"
"dario.cat/mergo"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
"level": l.Level.String(),
"level": testLogLevelName(),
},
"timers": m{
"pending_deletion_interval": 2,
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
"level": testLogLevelName(),
},
"timers": m{
"pending_deletion_interval": 2,
@@ -294,12 +294,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
controlB.InjectTunPacket(BuildTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")))
bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
controlA.InjectTunPacket(BuildTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")))
aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
@@ -379,24 +379,87 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
return a
}
func NewTestLogger() *logrus.Logger {
l := logrus.New()
func NewTestLogger() *slog.Logger {
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(io.Discard)
l.SetLevel(logrus.PanicLevel)
return l
return slog.New(slog.NewTextHandler(io.Discard, nil))
}
level := slog.LevelInfo
switch v {
case "2":
l.SetLevel(logrus.DebugLevel)
level = slog.LevelDebug
case "3":
l.SetLevel(logrus.TraceLevel)
default:
l.SetLevel(logrus.InfoLevel)
level = logging.LevelTrace
}
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
}
// testLogLevelName returns the level name string accepted by logging.ApplyConfig
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
func testLogLevelName() string {
switch os.Getenv("TEST_LOGS") {
case "2":
return "debug"
case "3":
return "trace"
case "":
return "info"
}
return "info"
}
// BuildTunUDPPacket assembles an IP+UDP packet suitable for Control.InjectTunPacket.
// Using UDP here because it's a simpler protocol.
//
// The IP version is chosen from toAddr; mixing an IPv4 source with an IPv6
// destination (or vice versa) panics, since such a packet cannot be built.
func BuildTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) []byte {
	serialize := make([]gopacket.SerializableLayer, 0)
	var netLayer gopacket.NetworkLayer
	if toAddr.Is6() {
		if !fromAddr.Is6() {
			panic("Cant send ipv6 to ipv4")
		}
		ip := &layers.IPv6{
			Version:    6,
			NextHeader: layers.IPProtocolUDP,
			SrcIP:      fromAddr.Unmap().AsSlice(),
			DstIP:      toAddr.Unmap().AsSlice(),
		}
		serialize = append(serialize, ip)
		netLayer = ip
	} else {
		if !fromAddr.Is4() {
			panic("Cant send ipv4 to ipv6")
		}
		ip := &layers.IPv4{
			Version:  4,
			TTL:      64,
			Protocol: layers.IPProtocolUDP,
			SrcIP:    fromAddr.Unmap().AsSlice(),
			DstIP:    toAddr.Unmap().AsSlice(),
		}
		serialize = append(serialize, ip)
		netLayer = ip
	}

	udp := layers.UDP{
		SrcPort: layers.UDPPort(fromPort),
		DstPort: layers.UDPPort(toPort),
	}
	// The UDP checksum covers a pseudo-header, so it needs the network layer.
	if err := udp.SetNetworkLayerForChecksum(netLayer); err != nil {
		panic(err)
	}

	buffer := gopacket.NewSerializeBuffer()
	opt := gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}
	serialize = append(serialize, &udp, gopacket.Payload(data))
	if err := gopacket.SerializeLayers(buffer, opt, serialize...); err != nil {
		panic(err)
	}

	return buffer.Bytes()
}

51
e2e/leak_test.go Normal file
View File

@@ -0,0 +1,51 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"go.uber.org/goleak"
)
// TestNoGoroutineLeaks brings up two nebula instances, completes a tunnel,
// stops both, and asserts no goroutines leak past the shutdown. goleak's
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
// before failing the assertion.
//
// IgnoreCurrent is necessary in the parallelized suite: other tests can
// leave goroutines mid-shutdown when this one runs (Stop is async, the
// wg.Wait() drain is not blocking on test return). We're checking that
// *this* test's setup tears down cleanly, not that the whole suite is
// idle at this moment. Intentionally NOT t.Parallel()'d for the same
// reason — concurrent test goroutines would always show up.
func TestNoGoroutineLeaks(t *testing.T) {
	// Registered first so it runs after everything below has torn down.
	defer goleak.VerifyNone(t, goleak.IgnoreCurrent())

	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	ctlMe, netMe, addrMe, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	ctlThem, netThem, addrThem, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)

	// Seed each side's lighthouse with the other's underlay address.
	ctlMe.InjectLightHouseAddr(netThem[0].Addr(), addrThem)
	ctlThem.InjectLightHouseAddr(netMe[0].Addr(), addrMe)

	ctlMe.Start()
	ctlThem.Start()

	r := router.NewR(t, ctlMe, ctlThem)
	assertTunnel(t, netMe[0].Addr(), netThem[0].Addr(), ctlMe, ctlThem, r)

	ctlMe.Stop()
	ctlThem.Stop()
	r.RenderFlow()

	// Settle period: Stop() is non-blocking; the wg-driven goroutines need
	// a moment to drain. goleak retries internally too, but a short explicit
	// settle reduces flakes when the suite is busy.
	time.Sleep(50 * time.Millisecond)
}

View File

@@ -13,6 +13,7 @@ import (
"regexp"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
@@ -24,6 +25,19 @@ import (
"golang.org/x/exp/maps"
)
// outNatKey is the (from, to) pair used by outNat. Comparable struct, so it works as a map key without the
// allocation cost of a string-concat key (the previous scheme concatenated the two addresses' String() forms).
type outNatKey struct {
	// from and to are the underlay ip:port endpoints of a flow, in that order.
	from, to netip.AddrPort
}
// fannedPacket pairs a UDP TX packet with its source control so the router can route it after popping from
// the fan-in channel. The receiver is derived later from pkt.To by routeUDP.
type fannedPacket struct {
	from *nebula.Control
	pkt  *udp.Packet
}
type R struct {
// Simple map of the ip:port registered on a control to the control
// Basically a router, right?
@@ -34,12 +48,28 @@ type R struct {
// A last used map, if an inbound packet hit the inNat map then
// all return packets should use the same last used inbound address for the outbound sender
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
outNat map[string]netip.AddrPort
outNat map[outNatKey]netip.AddrPort
// A map of vpn ip to the nebula control it belongs to
vpnControls map[netip.Addr]*nebula.Control
// Cached select infrastructure for RouteForAllUntilTxTun.
// The controls map is immutable after NewR so the cases are good for the test lifetime.
// We only rebuild if a different receiver is asked.
selRecvCtl *nebula.Control
selCases []reflect.SelectCase
selCtls []*nebula.Control
// Optional fan-in mode for hot-path benchmarks: one forwarder goroutine per control drains UDP TX into udpFanIn,
// so RouteForAllUntilTxTun can do a fixed 2-way native select instead of paying reflect.Select per call.
// Off by default (would otherwise interleave with tests that use GetFromUDP directly on the same control).
// Enabled by EnableFanIn.
udpFanIn chan fannedPacket
stopFanIn chan struct{}
fanInWG sync.WaitGroup
fanInMu sync.Mutex
fanInOn atomic.Bool
ignoreFlows []ignoreFlow
flow []flowEntry
@@ -119,7 +149,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
controls: make(map[netip.AddrPort]*nebula.Control),
vpnControls: make(map[netip.Addr]*nebula.Control),
inNat: make(map[netip.AddrPort]*nebula.Control),
outNat: make(map[string]netip.AddrPort),
outNat: make(map[outNatKey]netip.AddrPort),
flow: []flowEntry{},
ignoreFlows: []ignoreFlow{},
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -153,8 +183,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
case <-ctx.Done():
return
case <-clockSource.C:
r.Lock()
r.renderHostmaps("clock tick")
r.renderFlow()
r.Unlock()
}
}
}()
@@ -180,15 +212,21 @@ func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
// RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening.
// The router lock is held for the write so routing paths can't append to r.flow mid-render.
func (r *R) RenderFlow() {
	r.cancelRender()
	r.Lock()
	defer r.Unlock()
	r.renderFlow()
}
// CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected
func (r *R) CancelFlowLogs() {
	r.cancelRender()
	r.Lock()
	defer r.Unlock()
	// A nil flow disables all further collection and rendering.
	r.flow = nil
}
// renderFlow writes the flow log to disk. Caller must hold r.Lock. renderFlow reads r.flow / r.additionalGraphs and
// the *packet pointers stashed inside, all of which are mutated under the same lock by routing paths.
func (r *R) renderFlow() {
if r.flow == nil {
return
@@ -434,68 +472,157 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
panic("No control for udp tx " + a.String())
}
fp := r.unlockedInjectFlow(sender, c, p, false)
c.InjectUDPPacket(p)
c.InjectUDPPacket(p) // copies internally; original is ours to release
fp.WasReceived()
r.Unlock()
p.Release()
}
}
}
// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on the receiver's tun.
// If a control's UDP TX address can't be matched to a registered control, we panic.
//
// For allocation-sensitive callers (hot-path benchmarks, in particular relay
// benches with 3+ controls), call EnableFanIn() first.
func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
	if r.fanInOn.Load() {
		return r.routeFanIn(receiver)
	}
	return r.routeReflect(receiver)
}
// routeFanIn is the alloc-free path used when EnableFanIn is in effect.
// It waits on exactly two channels: the receiver's TUN TX and the shared
// UDP fan-in, forwarding UDP traffic until a tun packet appears.
func (r *R) routeFanIn(receiver *nebula.Control) []byte {
	tunTx := receiver.GetTunTxChan()
	for {
		select {
		case fp := <-r.udpFanIn:
			// UDP traffic from some control; forward it and keep waiting.
			r.routeUDP(fp.from, fp.pkt)
		case data := <-tunTx:
			r.Lock()
			if r.flow != nil {
				// Log a copy so the returned slice stays untouched.
				cp := udp.Packet{Data: append([]byte(nil), data...)}
				r.unlockedInjectFlow(receiver, receiver, &cp, true)
			}
			r.Unlock()
			return data
		}
	}
}
// routeReflect is the default reflect.Select-based path. Pays the boxing allocation per call but doesn't interfere
// with tests that pull packets directly from controls' UDP TX channels via GetFromUDP.
func (r *R) routeReflect(receiver *nebula.Control) []byte {
	cases, ctls := r.selectCasesFor(receiver)
	for {
		idx, val, _ := reflect.Select(cases)
		if idx != 0 {
			// One of the controls' UDP TX channels fired; route it and loop.
			r.routeUDP(ctls[idx], val.Interface().(*udp.Packet))
			continue
		}
		// Slot 0 is the receiver's tun; we're done.
		data := val.Interface().([]byte)
		r.Lock()
		if r.flow != nil {
			// Log a copy so the returned slice stays untouched.
			cp := udp.Packet{Data: append([]byte(nil), data...)}
			r.unlockedInjectFlow(ctls[idx], ctls[idx], &cp, true)
		}
		r.Unlock()
		return data
	}
}
// EnableFanIn switches RouteForAllUntilTxTun to the alloc-free fan-in path.
// One forwarder goroutine per registered control drains UDP TX into a shared channel that RouteForAllUntilTxTun selects
// on alongside the receiver's TUN TX channel. Idempotent; teardown is registered via t.Cleanup.
func (r *R) EnableFanIn() {
	r.fanInMu.Lock()
	defer r.fanInMu.Unlock()
	if r.fanInOn.Load() {
		// Workers are already running; nothing to do.
		return
	}

	r.udpFanIn = make(chan fannedPacket, 32)
	r.stopFanIn = make(chan struct{})
	for _, ctl := range r.controls {
		r.startFanInWorker(ctl)
	}

	r.fanInOn.Store(true)
	r.t.Cleanup(r.stopFanInWorkers)
}
// startFanInWorker spawns a goroutine that drains c's UDP TX into r.udpFanIn.
// The goroutine exits when r.stopFanIn is closed and is tracked by r.fanInWG.
func (r *R) startFanInWorker(c *nebula.Control) {
	r.fanInWG.Add(1)
	udpTx := c.GetUDPTxChan()
	go func() {
		defer r.fanInWG.Done()
		for {
			select {
			case <-r.stopFanIn:
				return
			case p := <-udpTx:
				// Re-check stop while blocked on the (possibly full) fan-in
				// channel, so shutdown can't deadlock against a consumer that
				// has already gone away; release the packet we own on exit.
				select {
				case <-r.stopFanIn:
					p.Release()
					return
				case r.udpFanIn <- fannedPacket{from: c, pkt: p}:
				}
			}
		}
	}()
}
// stopFanInWorkers signals the fan-in goroutines to exit and waits for them.
// Safe to call when fan-in was never enabled (it returns immediately).
func (r *R) stopFanInWorkers() {
	r.fanInMu.Lock()
	// Swap under fanInMu so this can't interleave with EnableFanIn.
	wasOn := r.fanInOn.Swap(false)
	r.fanInMu.Unlock()
	if !wasOn {
		return
	}
	close(r.stopFanIn)
	r.fanInWG.Wait()
}
// routeUDP forwards a UDP TX packet from the named source control to the destination control derived from p.To,
// releasing the source packet after InjectUDPPacket has copied its bytes into a fresh pool slot.
// Panics if p.To doesn't resolve to a registered control.
func (r *R) routeUDP(from *nebula.Control, p *udp.Packet) {
	r.Lock()
	defer r.Unlock()
	a := from.GetUDPAddr()
	// getControl may rewrite p.From if a NAT mapping applies.
	c := r.getControl(a, p.To, p)
	if c == nil {
		panic(fmt.Sprintf("No control for udp tx %s", p.To))
	}

	fp := r.unlockedInjectFlow(from, c, p, false)
	c.InjectUDPPacket(p) // copies internally; original is ours to release
	fp.WasReceived()
	p.Release()
}
// selectCasesFor returns the SelectCase array used by routeReflect: one slot for the receiver's TUN TX channel followed
// by one per control's UDP TX channel. Cached for the test lifetime, only rebuilt if the receiver changes.
// The returned slices are aligned: cm[i] is the control whose channel sits in sc[i] (cm[0] is the receiver).
func (r *R) selectCasesFor(receiver *nebula.Control) ([]reflect.SelectCase, []*nebula.Control) {
	r.Lock()
	defer r.Unlock()
	if r.selRecvCtl == receiver && r.selCases != nil {
		return r.selCases, r.selCtls
	}

	sc := make([]reflect.SelectCase, len(r.controls)+1)
	cm := make([]*nebula.Control, len(r.controls)+1)

	// Slot 0 is always the receiver's tun; routeReflect exits when it fires.
	sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan())}
	cm[0] = receiver

	i := 1
	for _, c := range r.controls {
		sc[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan())}
		cm[i] = c
		i++
	}

	r.selRecvCtl = receiver
	r.selCases = sc
	r.selCtls = cm
	return sc, cm
}
// RouteExitFunc will call the whatDo func with each udp packet from sender.
@@ -522,6 +649,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
switch e {
case ExitNow:
r.Unlock()
p.Release()
return
case RouteAndExit:
@@ -529,6 +657,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
receiver.InjectUDPPacket(p)
fp.WasReceived()
r.Unlock()
p.Release()
return
case KeepRouting:
@@ -541,6 +670,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
}
r.Unlock()
p.Release()
}
}
@@ -641,6 +771,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
switch e {
case ExitNow:
r.Unlock()
p.Release()
return
case RouteAndExit:
@@ -648,6 +779,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
receiver.InjectUDPPacket(p)
fp.WasReceived()
r.Unlock()
p.Release()
return
case KeepRouting:
@@ -659,6 +791,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
}
r.Unlock()
p.Release()
}
}
@@ -702,19 +835,20 @@ func (r *R) FlushAll() {
}
receiver.InjectUDPPacket(p)
r.Unlock()
p.Release()
}
}
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
if newAddr, ok := r.outNat[outNatKey{from: fromAddr, to: toAddr}]; ok {
p.From = newAddr
}
c, ok := r.inNat[toAddr]
if ok {
r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
r.outNat[outNatKey{from: c.GetUDPAddr(), to: fromAddr}] = toAddr
return c
}

View File

@@ -12,11 +12,14 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestDropInactiveTunnels(t *testing.T) {
t.Parallel()
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -61,6 +64,7 @@ func TestDropInactiveTunnels(t *testing.T) {
}
func TestCertUpgrade(t *testing.T) {
t.Parallel()
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -155,6 +159,7 @@ func TestCertUpgrade(t *testing.T) {
}
func TestCertDowngrade(t *testing.T) {
t.Parallel()
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -253,6 +258,7 @@ func TestCertDowngrade(t *testing.T) {
}
func TestCertMismatchCorrection(t *testing.T) {
t.Parallel()
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -320,6 +326,7 @@ func TestCertMismatchCorrection(t *testing.T) {
}
func TestCrossStackRelaysWork(t *testing.T) {
t.Parallel()
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
@@ -348,14 +355,14 @@ func TestCrossStackRelaysWork(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
@@ -365,3 +372,107 @@ func TestCrossStackRelaysWork(t *testing.T) {
//theirControl.Stop()
//relayControl.Stop()
}
// TestCloseTunnelAuthenticated exercises tunnel teardown in two phases:
//  1. Happy path: after an explicit CloseTunnel, both hostmaps must empty out
//     within the inactivity window (bounded at 6s below).
//  2. Forgery check: after rebuilding the tunnel, a hand-built CloseTunnel
//     header injected from "my" underlay address must NOT clear "their"
//     hostmap — the loop asserts both sides keep their tunnel entries.
func TestCloseTunnelAuthenticated(t *testing.T) {
	t.Parallel()
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	// "me" drops inactive tunnels quickly (5s); "them" uses a long timeout (10m)
	// so their side can only be cleared via the CloseTunnel exchange, not their own timer.
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})

	// Share our underlay information
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Start the servers
	myControl.Start()
	theirControl.Start()

	r := router.NewR(t, myControl, theirControl)

	r.Log("Assert the tunnel between me and them works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	r.Log("Close the tunnel")
	myControl.CloseTunnel(theirVpnIpNet[0].Addr(), false)

	r.FlushAll()

	// Poll both hostmaps until the tunnel is gone on both sides, or fail after 6s.
	waitStart := time.Now()
	for {
		myIndexes := len(myControl.GetHostmap().Indexes)
		theirIndexes := len(theirControl.GetHostmap().Indexes)
		if myIndexes == 0 && theirIndexes == 0 {
			break
		}

		since := time.Since(waitStart)
		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
		if since > time.Second*6 {
			t.Fatal("Tunnel should have been declared inactive after 2 seconds and before 6 seconds")
		}

		time.Sleep(1 * time.Second)
		//r.FlushAll()
	}
	r.Logf("Happy path success, tunnels were dropped within %v", time.Since(waitStart))

	// Re-seed the lighthouse info and rebuild the tunnel for the forgery phase.
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	r.Log("Assert another tunnel between me and them works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	// Grab both sides' hostinfo so the forged header can name a valid remote index
	// and be injected from a plausible underlay address.
	hi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
	if hi == nil {
		t.Fatal("There is no hostinfo for this tunnel")
	}
	myHi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
	if myHi == nil {
		t.Fatal("There is no hostinfo for my tunnel")
	}
	r.Log("It does")

	// Hand-build a bogus CloseTunnel header: the RemoteIndex is real, but
	// MessageCounter is an arbitrary value (5) — presumably enough for the
	// receiver to reject it as unauthenticated; the loop below verifies that.
	buf := make([]byte, 1024)
	hdr := header.H{
		Version:        1,
		Type:           header.CloseTunnel,
		Subtype:        0,
		Reserved:       0,
		RemoteIndex:    hi.RemoteIndex,
		MessageCounter: 5,
	}
	out, err := hdr.Encode(buf)
	if err != nil {
		t.Fatal(err)
	}
	pkt := &udp.Packet{
		To:   hi.CurrentRemote,
		From: myHi.CurrentRemote,
		Data: out,
	}
	r.InjectUDPPacket(myControl, theirControl, pkt)
	r.Log("Injected bogus close tunnel. Let's see!")
	// Watch for ~4s: if either hostmap empties, the bogus close was accepted — fail.
	waitStart = time.Now()
	for {
		myIndexes := len(myControl.GetHostmap().Indexes)
		theirIndexes := len(theirControl.GetHostmap().Indexes)
		if myIndexes == 0 {
			t.Fatal("myIndexes should not be 0")
		}
		if theirIndexes == 0 {
			t.Fatal("theirIndexes should not be 0, they should have rejected this bogus packet")
		}

		since := time.Since(waitStart)
		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
		if since > time.Second*4 {
			t.Log("The tunnel would have been gone by now")
			break
		}

		time.Sleep(1 * time.Second)
		r.FlushAll()
	}

	myControl.Stop()
	theirControl.Stop()
}

View File

@@ -204,6 +204,12 @@ punchy:
# Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access.
#trusted_cas:
#- "ssh public key string"
# sandbox_dir restricts file paths for profiling commands (start-cpu-profile, save-heap-profile,
# save-mutex-profile) to the specified directory. Relative paths will be resolved within this directory,
# and absolute paths outside of it will be rejected. Default is $TMP/nebula-debug.
# The directory is NOT automatically created.
# Overriding this to "" is the same as "/" and will allow overwriting any path on the host.
#sandbox_dir: /var/tmp/nebula-debug
# EXPERIMENTAL: relay support for networks that can't establish direct connections.
relay:
@@ -327,24 +333,21 @@ tun:
# Configure logging level
logging:
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
# scenarios. Debug logging is also CPU intensive and will decrease performance overall.
# Only enable debug logging while actively investigating an issue.
# trace, debug, info, warn, or error. Default is info and is reloadable.
# fatal and panic are accepted for backwards compatibility and map to error.
#NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
# scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
# Only enable debug or trace logging while actively investigating an issue.
level: info
# json or text formats currently available. Default is text
# json or text formats currently available. Default is text.
format: text
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
# Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
#disable_timestamp: true
# timestamp format is specified in Go time format, see:
# https://golang.org/pkg/time/#pkg-constants
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
# default when `format: text`:
# when TTY attached: seconds since beginning of execution
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
# As an example, to log as RFC3339 with millisecond precision, set to:
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
# Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
# The stats section is reloadable. A HUP may change the backend, toggle stats
# on or off, switch the listen/host address, or pick up new DNS for the
# configured graphite host.
#stats:
#type: graphite
#prefix: nebula
@@ -362,10 +365,12 @@ logging:
# enables counter metrics for meta packets
# e.g.: `messages.tx.handshake`
# NOTE: `message.{tx,rx}.recv_error` is always emitted
# Not reloadable.
#message_metrics: false
# enables detailed counter metrics for lighthouse packets
# e.g.: `lighthouse.rx.HostQuery`
# Not reloadable.
#lighthouse_metrics: false
# Handshake Manager Settings
@@ -423,8 +428,8 @@ firewall:
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
# proto: `any`, `tcp`, `udp`, or `icmp`
# a port specification is ignored if proto is `icmp`
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass

View File

@@ -7,9 +7,9 @@ import (
"net"
"os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/service"
)
@@ -64,8 +64,7 @@ pki:
return err
}
logger := logrus.New()
logger.Out = os.Stdout
logger := logging.NewLogger(os.Stdout)
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {

View File

@@ -1,11 +1,13 @@
package nebula
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
"log/slog"
"net/netip"
"reflect"
"slices"
@@ -16,7 +18,6 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -67,7 +68,7 @@ type Firewall struct {
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
l *logrus.Logger
l *slog.Logger
}
type firewallMetrics struct {
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration
var tmin, tmax time.Duration
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop":
fw.InSendReject = false
default:
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`")
l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
fw.InSendReject = false
}
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop":
fw.OutSendReject = false
default:
l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`")
l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
fw.OutSendReject = false
}
@@ -249,20 +250,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
)
f.rules += ruleString + "\n"
direction := "incoming"
if !incoming {
direction = "outgoing"
}
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
ft *FirewallTable
fp firewallPort
@@ -280,6 +267,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
case firewall.ProtoUDP:
fp = ft.UDP
case firewall.ProtoICMP, firewall.ProtoICMPv6:
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
if startPort != firewall.PortAny {
f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
}
startPort = firewall.PortAny
endPort = firewall.PortAny
fp = ft.ICMP
case firewall.ProtoAny:
fp = ft.AnyProto
@@ -287,6 +280,21 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
)
f.rules += ruleString + "\n"
direction := "incoming"
if !incoming {
direction = "outgoing"
}
f.l.Info("Firewall rule added",
"firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
)
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
}
@@ -308,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
@@ -349,24 +357,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
sPort = r.Port
}
startPort, endPort, err := parsePort(sPort)
if err != nil {
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
}
var proto uint8
var startPort, endPort int32
switch r.Proto {
case "any":
proto = firewall.ProtoAny
startPort, endPort, err = parsePort(sPort)
case "tcp":
proto = firewall.ProtoTCP
startPort, endPort, err = parsePort(sPort)
case "udp":
proto = firewall.ProtoUDP
startPort, endPort, err = parsePort(sPort)
case "icmp":
proto = firewall.ProtoICMP
startPort = firewall.PortAny
endPort = firewall.PortAny
if sPort != "" {
l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
}
default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
if err != nil {
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
}
if r.Cidr != "" && r.Cidr != "any" {
_, err = netip.ParsePrefix(r.Cidr)
@@ -383,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
}
if warning := r.sanity(); warning != nil {
l.Warnf("%s rule #%v; %s", table, i, warning)
l.Warn("firewall rule sanity check",
"table", table,
"rule", i,
"warning", warning,
)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
@@ -467,7 +486,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics {
}
}
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
// Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new
// firewall object is created
func (f *Firewall) Destroy() {
//TODO: clean references if/when needed
@@ -515,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
"fwPacket", fp,
"incoming", c.incoming,
"rulesVersion", f.rulesVersion,
"oldRulesVersion", c.rulesVersion,
)
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
}
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
"fwPacket", fp,
"incoming", c.incoming,
"rulesVersion", f.rulesVersion,
"oldRulesVersion", c.rulesVersion,
)
}
c.rulesVersion = f.rulesVersion
@@ -660,6 +679,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return false
}
// this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match
if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 {
// port numbers are re-used for connection tracking of ICMP,
// but we don't want to actually filter on them.
return fp[firewall.PortAny].match(p, c, caPool)
}
var port int32
if p.Fragment {
@@ -804,10 +830,8 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
return true
}
for _, group := range groups {
if group == "any" {
return true
}
if slices.Contains(groups, "any") {
return true
}
if host == "any" {
@@ -917,7 +941,7 @@ type rule struct {
CASha string
}
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
r := rule{}
m, ok := p.(map[string]any)
@@ -948,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
}
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
l.Warn("group was an array with a single value, converting to simple value",
"table", table,
"rule", i,
)
m["group"] = v[0]
}
@@ -1018,54 +1045,56 @@ func (r *rule) sanity() error {
}
}
if r.Code != "" {
return fmt.Errorf("code specified as [%s]. Support for 'code' will be dropped in a future release, as it has never been functional", r.Code)
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) {
func parsePort(s string) (int32, int32, error) {
var err error
const notAPort int32 = -2
if s == "any" {
startPort = firewall.PortAny
endPort = firewall.PortAny
} else if s == "fragment" {
startPort = firewall.PortFragment
endPort = firewall.PortFragment
} else if strings.Contains(s, `-`) {
sPorts := strings.SplitN(s, `-`, 2)
sPorts[0] = strings.Trim(sPorts[0], " ")
sPorts[1] = strings.Trim(sPorts[1], " ")
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
}
rStartPort, err := strconv.Atoi(sPorts[0])
if err != nil {
return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
}
rEndPort, err := strconv.Atoi(sPorts[1])
if err != nil {
return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
}
startPort = int32(rStartPort)
endPort = int32(rEndPort)
if startPort == firewall.PortAny {
endPort = firewall.PortAny
}
} else {
return firewall.PortAny, firewall.PortAny, nil
}
if s == "fragment" {
return firewall.PortFragment, firewall.PortFragment, nil
}
if !strings.Contains(s, `-`) {
rPort, err := strconv.Atoi(s)
if err != nil {
return 0, 0, fmt.Errorf("was not a number; `%s`", s)
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
}
startPort = int32(rPort)
endPort = startPort
return int32(rPort), int32(rPort), nil
}
return
sPorts := strings.SplitN(s, `-`, 2)
for i := range sPorts {
sPorts[i] = strings.Trim(sPorts[i], " ")
}
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
}
rStartPort, err := strconv.Atoi(sPorts[0])
if err != nil {
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
}
rEndPort, err := strconv.Atoi(sPorts[1])
if err != nil {
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
}
startPort := int32(rStartPort)
endPort := int32(rEndPort)
if startPort == firewall.PortAny {
endPort = firewall.PortAny
}
return startPort, endPort, nil
}

View File

@@ -1,10 +1,10 @@
package firewall
import (
"context"
"log/slog"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// ConntrackCache is used as a local routine cache to know if a given flow
@@ -15,41 +15,49 @@ type ConntrackCacheTicker struct {
cacheV uint64
cacheTick atomic.Uint64
l *slog.Logger
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
l: l,
cache: ConntrackCache{},
}
go c.tick(d)
go c.tick(ctx, d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
t := time.NewTicker(d)
defer t.Stop()
for {
time.Sleep(d)
c.cacheTick.Add(1)
select {
case <-ctx.Done():
return
case <-t.C:
c.cacheTick.Add(1)
}
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
func (c *ConntrackCacheTicker) Get() ConntrackCache {
if c == nil {
return nil
}
if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
if c.l.Enabled(context.Background(), slog.LevelDebug) {
c.l.Debug("resetting conntrack cache", "len", ll)
}
c.cache = make(ConntrackCache, ll)
}

69
firewall/cache_test.go Normal file
View File

@@ -0,0 +1,69 @@
package firewall
import (
"bytes"
"log/slog"
"strings"
"testing"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
// The tests below pin the log format produced by ConntrackCacheTicker.Get
// so changes cannot silently break what operators are grepping for. The
// ticker's internal state (cache + cacheTick) is poked directly to avoid
// racing a goroutine-driven tick in tests.
// newFixedTicker builds a ConntrackCacheTicker pre-populated with cacheLen
// distinct cache entries, with its tick counter advanced so that the next
// Get() call takes the reset path (cacheV starts at 0, so a stored tick of
// 1 differs). State is poked directly to avoid racing a real goroutine.
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
	t.Helper()
	ticker := &ConntrackCacheTicker{
		l:     l,
		cache: make(ConntrackCache, cacheLen),
	}
	for n := 1; n <= cacheLen; n++ {
		ticker.cache[Packet{LocalPort: uint16(n)}] = struct{}{}
	}
	ticker.cacheTick.Store(1)
	return ticker
}
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
c := newFixedTicker(t, l, 3)
c.Get()
assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String())
}
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
buf := &bytes.Buffer{}
l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug)
c := newFixedTicker(t, l, 2)
c.Get()
assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
}
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
c := newFixedTicker(t, l, 5)
c.Get()
assert.Empty(t, buf.String())
}
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
c := newFixedTicker(t, l, 0)
c.Get()
assert.Empty(t, buf.String())
}

View File

@@ -23,7 +23,10 @@ const (
type Packet struct {
LocalAddr netip.Addr
RemoteAddr netip.Addr
LocalPort uint16
// LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP.
LocalPort uint16
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
RemotePort uint16
Protocol uint8
Fragment bool
@@ -47,6 +50,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = "tcp"
case ProtoICMP:
proto = "icmp"
case ProtoICMPv6:
proto = "icmpv6"
case ProtoUDP:
proto = "udp"
default:

View File

@@ -3,13 +3,13 @@ package nebula
import (
"bytes"
"errors"
"log/slog"
"math"
"net/netip"
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
}
func TestFirewall_AddRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@@ -87,9 +86,10 @@ func TestFirewall_AddRule(t *testing.T) {
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
//no matter what port is given for icmp, it should end up as "any"
assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
@@ -176,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
}
func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
@@ -253,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
}
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -484,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -543,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
}
func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -632,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
}
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -670,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -734,10 +728,152 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
func TestFirewall_ICMPPortBehavior(t *testing.T) {
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
network := netip.MustParsePrefix("1.2.3.4/24")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
groups: []string{"default-group"},
issuer: "signer-shasum",
},
InvertedGroups: map[string]struct{}{"default-group": {}},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
cp := cert.NewCAPool()
templ := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
Protocol: firewall.ProtoICMP,
Fragment: false,
}
t.Run("ICMP allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
})
t.Run("Any proto, some ports allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
//different ID is blocked
p.RemotePort++
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
ob := &bytes.Buffer{}
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
@@ -897,56 +1033,56 @@ func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger()
// Test a bad rule definition
c := &dummyCert{}
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil, "aes")
require.NoError(t, err)
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
@@ -955,28 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
func TestAddFirewallRulesFromConfig(t *testing.T) {
l := test.NewLogger()
// Test adding tcp rule
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule no port
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -984,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -999,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
@@ -1082,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
}
func TestFirewall_convertRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
// Ensure group array of 1 is converted and a warning is printed
c := map[string]any{
@@ -1092,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
}
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
assert.Contains(t, ob.String(), "table=test")
assert.Contains(t, ob.String(), "rule=1")
require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups)
@@ -1118,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
@@ -1234,7 +1377,7 @@ type testsetup struct {
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
@@ -1245,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
@@ -1262,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out

25
go.mod
View File

@@ -1,9 +1,10 @@
module github.com/slackhq/nebula
go 1.25
go 1.25.0
require (
dario.cat/mergo v1.0.2
filippo.io/bigmod v0.1.0
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
@@ -12,26 +13,26 @@ require (
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
github.com/miekg/dns v1.1.70
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/miekg/dns v1.1.72
github.com/miekg/pkcs11 v1.1.2
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.23.2
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.4
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
go.uber.org/goleak v1.3.0
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.47.0
golang.org/x/crypto v0.50.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/sys v0.40.0
golang.org/x/term v0.39.0
golang.org/x/net v0.52.0
golang.org/x/sync v0.20.0
golang.org/x/sys v0.43.0
golang.org/x/term v0.42.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
golang.zx2c4.com/wireguard/windows v0.6.1
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
@@ -49,7 +50,7 @@ require (
github.com/prometheus/procfs v0.16.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.org/x/tools v0.43.0 // indirect
)

44
go.sum
View File

@@ -1,6 +1,8 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8=
filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -83,10 +85,10 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ=
github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -131,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
@@ -162,16 +162,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -223,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -233,8 +233,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
golang.zx2c4.com/wireguard/windows v0.6.1 h1:XMaKojH1Hs/raMrmnir4n35nTvzvWj7NmSYzHn2F4qU=
golang.zx2c4.com/wireguard/windows v0.6.1/go.mod h1:04aqInu5GYuTFvMuDw/rKBAF7mHrltW/3rekpfbbZDM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=

57
handshake/credential.go Normal file
View File

@@ -0,0 +1,57 @@
package handshake
import (
"crypto/rand"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
)
// Credential holds everything needed to participate in a handshake
// at a given cert version. Version and Curve are read from Cert; the public
// half of the static keypair likewise comes from Cert.PublicKey().
type Credential struct {
	Cert        cert.Certificate  // the certificate
	Bytes       []byte            // pre-marshaled certificate bytes (MarshalForHandshakes output)
	privateKey  []byte            // static private key (public half lives in Cert); never leaves this package
	cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash), consumed by buildHandshakeState
}
// NewCredential bundles a certificate, its pre-marshaled handshake bytes,
// the static private key, and a ready-made noise cipher suite into a
// Credential. The cipherSuite must already carry the appropriate DH
// function, cipher, and hash for the certificate's curve.
func NewCredential(
	c cert.Certificate,
	hsBytes []byte,
	privateKey []byte,
	cipherSuite noise.CipherSuite,
) *Credential {
	cred := Credential{
		cipherSuite: cipherSuite,
		privateKey:  privateKey,
		Bytes:       hsBytes,
		Cert:        c,
	}
	return &cred
}
// buildHandshakeState constructs a fresh noise.HandshakeState seeded with
// this credential's static keypair, cipher suite, and the given pattern.
func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) {
	cfg := noise.Config{
		CipherSuite:           hc.cipherSuite,
		Random:                rand.Reader,
		Pattern:               pattern,
		Initiator:             initiator,
		StaticKeypair:         noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()},
		PresharedKey:          []byte{},
		PresharedKeyPlacement: 0,
	}
	return noise.NewHandshakeState(cfg)
}
// GetCredentialFunc returns the handshake credential for the given version,
// or nil if that version is not available.
//
// Implementations must return credentials drawn from a snapshot stable for
// the lifetime of any single Machine. The Machine may call this multiple
// times during a handshake (e.g. when negotiating to the peer's version)
// and assumes the underlying static keypair is consistent across calls.
type GetCredentialFunc func(v cert.Version) *Credential

21
handshake/errors.go Normal file
View File

@@ -0,0 +1,21 @@
package handshake
import "errors"
// Sentinel errors returned by the handshake Machine. Callers should compare
// with errors.Is; several are wrapped with additional context before return.
// After receiving any of these, check Machine.Failed() to learn whether the
// Machine is still usable.
var (
	ErrInitiateOnResponder     = errors.New("initiate called on responder")
	ErrInitiateAlreadyCalled   = errors.New("initiate already called")
	ErrInitiateNotCalled       = errors.New("initiate must be called before ProcessPacket for initiators")
	ErrPacketTooShort          = errors.New("packet too short")
	ErrPublicKeyMismatch       = errors.New("public key mismatch between certificate and handshake")
	ErrIncompleteHandshake     = errors.New("handshake completed without receiving required content")
	ErrMachineFailed           = errors.New("handshake machine has failed")
	ErrUnknownSubtype          = errors.New("unknown handshake subtype")
	ErrMissingContent          = errors.New("expected handshake content but message was empty")
	ErrUnexpectedContent       = errors.New("received unexpected handshake content")
	ErrIndexAllocation         = errors.New("failed to allocate local index")
	ErrNoCredential            = errors.New("no handshake credential available for cert version")
	ErrAsymmetricCipherKeys    = errors.New("noise produced only one cipher key")
	ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager")
	ErrSubtypeMismatch         = errors.New("packet subtype does not match handshake machine subtype")
)

37
handshake/handshake.proto Normal file
View File

@@ -0,0 +1,37 @@
// This file documents the wire format the nebula handshake speaks. It is
// not run through protoc; the encoder/decoder in payload.go is hand-written
// against this shape directly to keep the parser narrow and panic-free.
//
// Any change to the wire format must be reflected here, and adding a new
// field requires updating MarshalPayload / unmarshalPayloadDetails together
// with the field-uniqueness and wire-type checks in those functions.
syntax = "proto3";

package nebula.handshake;

// NebulaHandshake is the outermost wire message: the handshake details plus
// an authenticator. NOTE(review): Hmac is presumably computed over the
// marshaled Details — confirm against the hand-written codec in payload.go.
message NebulaHandshake {
  NebulaHandshakeDetails Details = 1;
  bytes Hmac = 2;
}

message NebulaHandshakeDetails {
  bytes Cert = 1;            // pre-marshaled certificate (MarshalForHandshakes output)
  uint32 InitiatorIndex = 2; // session index allocated by the initiator; 0 means unset
  uint32 ResponderIndex = 3; // session index allocated by the responder; 0 means unset
  // Cookie was reserved for an anti-DoS mechanism that was never
  // implemented. No released version of nebula has ever populated it; the
  // hand-written parser silently skips it on read.
  uint64 Cookie = 4 [deprecated = true];
  uint64 Time = 5; // sender wall clock, nanoseconds since the Unix epoch
  // Field number 8 because 6 and 7 were already taken by the multiport fields.
  uint32 CertVersion = 8;
  MultiPortDetails InitiatorMultiPort = 6;
  MultiPortDetails ResponderMultiPort = 7;
}

message MultiPortDetails {
  bool RxSupported = 1;   // presumably: peer can receive on multiple ports — confirm at use sites
  bool TxSupported = 2;   // presumably: peer can transmit from multiple ports — confirm at use sites
  uint32 BasePort = 3;
  uint32 TotalPorts = 4;
}

116
handshake/helpers_test.go Normal file
View File

@@ -0,0 +1,116 @@
package handshake
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/require"
)
// testCertState holds cert material for a test peer.
type testCertState struct {
	version cert.Version                 // cert version this peer handshakes at
	creds   map[cert.Version]*Credential // credentials keyed by cert version; nil lookup means unavailable
}
// getCredential satisfies GetCredentialFunc. Unknown versions yield nil,
// matching the "not available" contract of GetCredentialFunc.
func (s *testCertState) getCredential(v cert.Version) *Credential {
	cred, ok := s.creds[v]
	if !ok {
		return nil
	}
	return cred
}
// newTestCertState builds a test peer's cert state with the default
// ChaChaPoly cipher; newTestCertStateWithCipher does the real work.
func newTestCertState(
	t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
) *testCertState {
	return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly)
}
// newTestCertStateWithCipher mints a Version2 cert for a test peer, unwraps
// its private key, pre-marshals the handshake bytes, and wraps everything in
// a single-version Credential map using the supplied noise cipher.
func newTestCertStateWithCipher(
	t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
	cipher noise.CipherFunc,
) *testCertState {
	t.Helper()

	crt, _, keyPEM, _ := ct.NewTestCert(
		cert.Version2, cert.Curve_CURVE25519, ca, caKey,
		name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
	)

	privKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(keyPEM)
	require.NoError(t, err)

	marshaled, err := crt.MarshalForHandshakes()
	require.NoError(t, err)

	suite := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256)
	creds := map[cert.Version]*Credential{
		cert.Version2: NewCredential(crt, marshaled, privKey, suite),
	}
	return &testCertState{version: cert.Version2, creds: creds}
}
// testVerifier adapts a CAPool into a CertVerifier that validates each
// certificate against the pool at the moment of the call.
func testVerifier(pool *cert.CAPool) CertVerifier {
	return func(c cert.Certificate) (*cert.CachedCertificate, error) {
		now := time.Now()
		return pool.VerifyCertificate(now, c)
	}
}
// newTestMachine builds a Machine for the IXPSK0 subtype whose index
// allocator always hands back the fixed localIndex.
func newTestMachine(
	t *testing.T,
	cs *testCertState,
	verifier CertVerifier,
	initiator bool,
	localIndex uint32,
) *Machine {
	t.Helper()

	alloc := func() (uint32, error) { return localIndex, nil }
	machine, err := NewMachine(
		cs.version, cs.getCredential,
		verifier, alloc,
		initiator, header.HandshakeIXPSK0,
	)
	require.NoError(t, err)
	return machine
}
// initiateHandshake runs the first round trip: the initiator produces
// message 1 (failing the test if that errors) and the responder processes
// it. The responder's reply, Result, and error are returned alongside both
// machines so callers can assert on either side.
func initiateHandshake(
	t *testing.T,
	initCS *testCertState, initVerifier CertVerifier,
	respCS *testCertState, respVerifier CertVerifier,
) (initM, respM *Machine, respResult *Result, resp []byte, err error) {
	t.Helper()

	initM = newTestMachine(t, initCS, initVerifier, true, 100)
	msg1, merr := initM.Initiate(nil)
	require.NoError(t, merr)

	respM = newTestMachine(t, respCS, respVerifier, false, 200)
	resp, respResult, err = respM.ProcessPacket(nil, msg1)
	return initM, respM, respResult, resp, err
}
// doFullHandshake drives a complete handshake between two test peers
// (initiator index 1000, responder index 2000), asserting every step
// succeeds, and returns both sides' Results.
func doFullHandshake(
	t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool,
) (initResult, respResult *Result) {
	t.Helper()
	v := testVerifier(caPool)
	initM := newTestMachine(t, initCS, v, true, 1000)
	respM := newTestMachine(t, respCS, v, false, 2000)
	// Message 1: initiator -> responder.
	msg1, err := initM.Initiate(nil)
	require.NoError(t, err)
	// Message 2: responder completes on read and produces its reply.
	resp, respResult, err := respM.ProcessPacket(nil, msg1)
	require.NoError(t, err)
	require.NotNil(t, respResult)
	require.NotEmpty(t, resp)
	// Initiator completes by reading the reply; no further output expected.
	_, initResult, err = initM.ProcessPacket(nil, resp)
	require.NoError(t, err)
	require.NotNil(t, initResult)
	return initResult, respResult
}

444
handshake/machine.go Normal file
View File

@@ -0,0 +1,444 @@
package handshake
import (
"bytes"
"fmt"
"slices"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
// IndexAllocator is called by the Machine to allocate a local index for the
// handshake. It is called at most once, when the first outgoing message that
// carries a payload is built.
//
// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning
// "no index assigned" on the wire and in the payload-presence checks. If an
// allocator ever returned 0, a legitimate handshake's payload could be
// indistinguishable from an empty one and would be rejected.
type IndexAllocator func() (uint32, error)

// CertVerifier is called by the Machine after reconstructing the peer's
// certificate from the handshake. The verifier performs all validation
// (CA trust, expiry, policy checks, allow lists). A non-nil error is fatal:
// the Machine marks itself failed and the handshake is abandoned.
type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
// Result contains the results of a successful handshake.
// Returned by ProcessPacket when the handshake is complete.
type Result struct {
	EKey          *noise.CipherState      // cipher state for sending; see the key-ordering notes in machine methods
	DKey          *noise.CipherState      // cipher state for receiving
	MyCert        cert.Certificate        // local cert actually sent (after any version negotiation)
	RemoteCert    *cert.CachedCertificate // peer cert as returned by the CertVerifier
	RemoteIndex   uint32                  // index the peer allocated for this session (from its payload)
	LocalIndex    uint32                  // index our IndexAllocator produced
	HandshakeTime uint64                  // peer payload Time field (UnixNano at the peer)
	MessageIndex  uint64                  // number of messages exchanged during the handshake
	Initiator     bool                    // true when this Machine initiated
}
// Machine drives a Noise handshake through N messages. It handles Noise
// protocol operations, certificate reconstruction, and payload encoding.
// Certificate validation is delegated to the caller via CertVerifier.
//
// A Machine is not safe for concurrent use. The caller must ensure that
// Initiate and ProcessPacket are not called concurrently.
//
// Error contract: when ProcessPacket or Initiate returns an error, callers
// must check Failed() to decide what to do next. If Failed() is false the
// underlying noise state was not advanced (the packet was rejected before
// ReadMessage took effect, or the rejection is non-fatal like a stale
// retransmit) and the Machine can accept another packet. If Failed() is
// true the Machine is unrecoverable and the caller must abandon it.
type Machine struct {
	hs             *noise.HandshakeState // live noise state, advanced by Read/WriteMessage
	getCred        GetCredentialFunc     // credential lookup; consulted again on version switch
	allocIndex     IndexAllocator        // invoked lazily, at most once, for the local index
	verifier       CertVerifier          // caller-supplied certificate validation
	result         *Result               // accumulated across messages; returned on completion
	msgs           []msgFlags            // expected content per message, indexed by MessageIndex
	myVersion      cert.Version          // cert version in use; may switch in validateCert
	subtype        header.MessageSubType // fixed at construction; gates incoming packets
	indexAllocated bool                  // allocIndex has been called (LocalIndex is valid)
	remoteCertSet  bool                  // peer certificate received and verified
	payloadSet     bool                  // peer indices/time received
	failed         bool                  // unrecoverable; see the error contract above
}
// NewMachine creates a handshake state machine. The subtype determines both
// the noise pattern and the per-message content layout. The credential for
// `version` is fetched via getCred and used to seed the noise.HandshakeState.
// allocIndex is invoked lazily when the first outgoing payload is built.
func NewMachine(
	version cert.Version,
	getCred GetCredentialFunc,
	verifier CertVerifier,
	allocIndex IndexAllocator,
	initiator bool,
	subtype header.MessageSubType,
) (*Machine, error) {
	meta, err := subtypeInfoFor(subtype)
	if err != nil {
		return nil, err
	}

	credential := getCred(version)
	if credential == nil {
		return nil, fmt.Errorf("%w: %v", ErrNoCredential, version)
	}

	state, err := credential.buildHandshakeState(initiator, meta.pattern)
	if err != nil {
		return nil, fmt.Errorf("build noise state: %w", err)
	}

	machine := &Machine{
		hs:         state,
		subtype:    subtype,
		msgs:       meta.msgs,
		getCred:    getCred,
		allocIndex: allocIndex,
		verifier:   verifier,
		myVersion:  version,
		result:     &Result{Initiator: initiator},
	}
	return machine, nil
}
// Failed returns true if the Machine is in an unrecoverable state. Once
// failed, Initiate and ProcessPacket only return ErrMachineFailed.
func (m *Machine) Failed() bool {
	return m.failed
}

// Subtype returns the handshake subtype this Machine was built for.
func (m *Machine) Subtype() header.MessageSubType {
	return m.subtype
}

// MessageIndex returns the noise handshake message index, which equals the
// wire counter of the most recently sent or received message.
func (m *Machine) MessageIndex() int {
	return m.hs.MessageIndex()
}
// requireComplete verifies that both the peer's payload and certificate have
// arrived by the time the handshake finishes; if either is missing the
// Machine is marked failed and ErrIncompleteHandshake is returned.
func (m *Machine) requireComplete() error {
	if m.payloadSet && m.remoteCertSet {
		return nil
	}
	m.failed = true
	return ErrIncompleteHandshake
}
// myMsgFlags returns the content flags for the outgoing message at the
// current noise index; past the end of the pattern it is all-zero.
func (m *Machine) myMsgFlags() msgFlags {
	if idx := m.hs.MessageIndex(); idx < len(m.msgs) {
		return m.msgs[idx]
	}
	return msgFlags{}
}
// peerMsgFlags returns the content flags for the message just consumed
// (noise index minus one); out-of-range indices yield the zero flags.
func (m *Machine) peerMsgFlags() msgFlags {
	idx := m.hs.MessageIndex() - 1
	if idx < 0 || idx >= len(m.msgs) {
		return msgFlags{}
	}
	return m.msgs[idx]
}
// Initiate produces the first handshake message. Only valid for initiators,
// and must be called exactly once before ProcessPacket.
//
// out is a destination buffer the message is appended to and returned. Pass
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
// buf[:0]) with sufficient capacity to avoid allocation.
//
// An error return may not indicate a fatal condition, check Failed() to
// determine if the Machine can still be used.
func (m *Machine) Initiate(out []byte) ([]byte, error) {
	if m.failed {
		return nil, ErrMachineFailed
	}
	// Every misuse below is fatal: the Machine is marked failed before return.
	if !m.result.Initiator {
		m.failed = true
		return nil, ErrInitiateOnResponder
	}
	if m.hs.MessageIndex() != 0 {
		m.failed = true
		return nil, ErrInitiateAlreadyCalled
	}
	// At MessageIndex=0 with RemoteIndex still zero, buildResponse produces
	// header counter 1 and remote index 0, which is what the initial message needs.
	out, _, _, err := m.buildResponse(out)
	if err != nil {
		m.failed = true
		return nil, err
	}
	return out, nil
}
// ProcessPacket handles an incoming handshake message. It advances the Noise
// state, validates the peer certificate via the verifier, and optionally
// produces a response.
//
// out is a destination buffer the response is appended to and returned. Pass
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
// buf[:0]) with sufficient capacity to avoid allocation. The returned slice
// is nil when no outgoing message is produced (handshake complete on this
// side, or final message of a multi-message pattern).
//
// Returns a non-nil Result when the handshake is complete.
// An error return may not indicate a fatal condition, check Failed() to
// determine if the Machine can still be used.
func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) {
	if m.failed {
		return nil, nil, ErrMachineFailed
	}
	if len(packet) < header.Len {
		return nil, nil, ErrPacketTooShort
	}
	// Reject packets whose subtype doesn't match the one this Machine was
	// built for. A pending handshake that suddenly receives a different
	// subtype on its index is either a stray packet that matched by chance
	// or a peer protocol violation; drop it without failing the Machine so
	// the legitimate retransmit can still complete.
	// NOTE(review): reading packet[1] as the subtype byte must stay in sync
	// with the header.Encode layout used in buildResponse — confirm there.
	if header.MessageSubType(packet[1]) != m.subtype {
		return nil, nil, ErrSubtypeMismatch
	}
	if m.result.Initiator && m.hs.MessageIndex() == 0 {
		m.failed = true
		return nil, nil, ErrInitiateNotCalled
	}
	// The (eKey, dKey) ordering here is correct for IX, where the initiator
	// completes the handshake by reading the responder's stage-2 message.
	// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
	// For 3-message patterns where a responder finishes by reading the final
	// message, this ordering would be wrong; revisit when XX/pqIX lands.
	msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:])
	if err != nil {
		// Noise ReadMessage failed. The noise library checkpoints and rolls back
		// on failure, so the Machine is still alive. The caller can retry with
		// a different packet.
		return nil, nil, fmt.Errorf("noise ReadMessage: %w", err)
	}
	// From here on, noise state has advanced. Any error is fatal.
	flags := m.peerMsgFlags()
	if err := m.processPayload(msg, flags); err != nil {
		return nil, nil, err
	}
	// If ReadMessage derived keys, the handshake is complete. Noise should
	// always produce both keys together; asymmetry is a protocol invariant
	// violation.
	if eKey != nil || dKey != nil {
		if eKey == nil || dKey == nil {
			m.failed = true
			return nil, nil, ErrAsymmetricCipherKeys
		}
		if err := m.requireComplete(); err != nil {
			return nil, nil, err
		}
		return nil, m.completed(eKey, dKey), nil
	}
	// ReadMessage didn't complete, produce the next outgoing message
	out, dk, ek, err := m.buildResponse(out)
	if err != nil {
		m.failed = true
		return nil, nil, err
	}
	// Writing the response may itself finish the handshake (the responder's
	// final message in a 2-message pattern). Same both-or-neither key check.
	if ek != nil || dk != nil {
		if ek == nil || dk == nil {
			m.failed = true
			return nil, nil, ErrAsymmetricCipherKeys
		}
		if err := m.requireComplete(); err != nil {
			return nil, nil, err
		}
		return out, m.completed(ek, dk), nil
	}
	return out, nil, nil
}
// completed stamps the final cipher states and the message count onto the
// accumulated result and hands it back to the caller.
func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result {
	r := m.result
	r.EKey, r.DKey = eKey, dKey
	r.MessageIndex = uint64(m.hs.MessageIndex())
	return r
}
// processPayload decodes and checks the plaintext content of the message
// just read against flags, which declare exactly what that message must
// carry. Every violation marks the Machine failed, because the noise state
// has already advanced by the time this runs.
func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
	if len(msg) == 0 {
		if flags.expectsPayload || flags.expectsCert {
			m.failed = true
			return ErrMissingContent
		}
		return nil
	}
	payload, err := UnmarshalPayload(msg)
	if err != nil {
		m.failed = true
		return fmt.Errorf("unmarshal handshake: %w", err)
	}
	// Assert the payload contains exactly what we expect.
	// This presence test relies on the IndexAllocator contract (never 0) and
	// on Time never being 0 in a populated payload.
	hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
	if hasPayloadData != flags.expectsPayload {
		m.failed = true
		return ErrUnexpectedContent
	}
	hasCertData := len(payload.Cert) > 0
	if hasCertData != flags.expectsCert {
		m.failed = true
		return ErrUnexpectedContent
	}
	// Process payload: record whichever index the peer allocated for itself.
	if flags.expectsPayload {
		if m.result.Initiator {
			m.result.RemoteIndex = payload.ResponderIndex
		} else {
			m.result.RemoteIndex = payload.InitiatorIndex
		}
		m.result.HandshakeTime = payload.Time
		m.payloadSet = true
	}
	// Process certificate
	if flags.expectsCert {
		if err := m.validateCert(payload); err != nil {
			return err
		}
	}
	return nil
}
// validateCert reconstructs the peer certificate from the handshake payload
// and the noise-authenticated static key, performs cert-version negotiation,
// and runs the caller-supplied verifier. Every failure here is fatal to the
// Machine. On success the verified cert is recorded on the result.
func (m *Machine) validateCert(payload Payload) error {
	// Our own credential supplies the curve needed to recombine the cert.
	cred := m.getCred(m.myVersion)
	if cred == nil {
		m.failed = true
		return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
	}
	rc, err := cert.Recombine(
		cert.Version(payload.CertVersion),
		payload.Cert,
		m.hs.PeerStatic(),
		cred.Cert.Curve(),
	)
	if err != nil {
		m.failed = true
		return fmt.Errorf("recombine cert: %w", err)
	}
	// The cert's key must be the same key noise authenticated for the peer.
	if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) {
		m.failed = true
		return ErrPublicKeyMismatch
	}
	// Version negotiation, if the peer sent a different version and we have it, switch
	if rc.Version() != m.myVersion {
		if m.getCred(rc.Version()) != nil {
			m.myVersion = rc.Version()
		}
	}
	verified, err := m.verifier(rc)
	if err != nil {
		m.failed = true
		return fmt.Errorf("verify cert: %w", err)
	}
	m.result.RemoteCert = verified
	m.remoteCertSet = true
	return nil
}
// marshalOutgoing builds the plaintext payload for the next outgoing message
// according to flags. Returns nil bytes (and no error) when the message
// carries no content. Allocates the local index on first need and records
// MyCert on the result when the certificate is attached.
func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
	if !flags.expectsPayload && !flags.expectsCert {
		return nil, nil
	}
	var p Payload
	if flags.expectsPayload {
		// Allocate our local index exactly once, the first time any payload
		// is produced. See the IndexAllocator contract: it must never return 0.
		if !m.indexAllocated {
			index, err := m.allocIndex()
			if err != nil {
				return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err)
			}
			m.result.LocalIndex = index
			m.indexAllocated = true
		}
		if m.result.Initiator {
			p.InitiatorIndex = m.result.LocalIndex
		} else {
			// Responders echo the initiator's index back alongside their own.
			p.ResponderIndex = m.result.LocalIndex
			p.InitiatorIndex = m.result.RemoteIndex
		}
		p.Time = uint64(time.Now().UnixNano())
	}
	if flags.expectsCert {
		// myVersion may have switched during negotiation; fetch fresh.
		cred := m.getCred(m.myVersion)
		if cred == nil {
			return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
		}
		p.Cert = cred.Bytes
		p.CertVersion = uint32(cred.Cert.Version())
		m.result.MyCert = cred.Cert
	}
	return MarshalPayload(nil, p), nil
}
// buildResponse appends a nebula header plus the next noise handshake
// message to out and returns it, together with the cipher states from
// WriteMessage (non-nil only when writing this message completes the
// handshake on our side).
func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
	flags := m.myMsgFlags()
	hsBytes, err := m.marshalOutgoing(flags)
	if err != nil {
		return nil, nil, nil, err
	}
	// Extend out by header.Len to make room for the header. slices.Grow is a
	// no-op when the cap is already sufficient (the zero-copy case where the
	// caller passed a pre-sized buffer). header.Encode overwrites the new
	// bytes, so they don't need to be zeroed.
	start := len(out)
	out = slices.Grow(out, header.Len)[:start+header.Len]
	header.Encode(
		out[start:],
		header.Version, header.Handshake, m.subtype,
		m.result.RemoteIndex,
		uint64(m.hs.MessageIndex()+1),
	)
	// noise.WriteMessage appends the encrypted handshake message to out,
	// reusing capacity when present.
	//
	// The (dKey, eKey) ordering here is correct for IX, where the responder
	// completes the handshake by writing the stage-2 message. noise returns
	// (cs1, cs2) where cs1 is the initiator->responder cipher (which is the
	// responder's decrypt key). For 3-message patterns where an initiator
	// finishes by writing the final message, this ordering would be wrong;
	// revisit when XX/pqIX lands.
	out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err)
	}
	return out, dKey, eKey, nil
}

662
handshake/machine_test.go Normal file
View File

@@ -0,0 +1,662 @@
package handshake
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/noiseutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMachineIXHappyPath drives a full two-message IX handshake between an
// initiator and a responder Machine, then confirms both derived cipher-state
// pairs interoperate in each direction.
func TestMachineIXHappyPath(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
	initR, respR := doFullHandshake(t, initCS, respCS, caPool)
	// Each side sees the peer's cert, and local/remote indexes mirror each
	// other (the 1000/2000 values presumably come from doFullHandshake's
	// index allocators — confirm against the helper).
	assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name())
	assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name())
	assert.Equal(t, uint32(1000), initR.LocalIndex)
	assert.Equal(t, uint32(2000), initR.RemoteIndex)
	assert.Equal(t, uint32(2000), respR.LocalIndex)
	assert.Equal(t, uint32(1000), respR.RemoteIndex)
	assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages")
	assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages")
	// Round-trip data both ways with the derived keys.
	ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello"))
	require.NoError(t, err)
	pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
	require.NoError(t, err)
	assert.Equal(t, []byte("hello"), pt1)
	ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world"))
	require.NoError(t, err)
	pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
	require.NoError(t, err)
	assert.Equal(t, []byte("world"), pt2)
}
// TestMachineInitiateErrors covers the misuse paths of Initiate: calling it
// on a responder, calling it twice, processing a packet before initiating,
// and reusing a machine that has already failed.
func TestMachineInitiateErrors(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	v := testVerifier(caPool)
	t.Run("initiate on responder", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		_, err := m.Initiate(nil)
		require.ErrorIs(t, err, ErrInitiateOnResponder)
		assert.True(t, m.Failed())
	})
	t.Run("initiate called twice", func(t *testing.T) {
		m := newTestMachine(t, cs, v, true, 100)
		_, err := m.Initiate(nil)
		require.NoError(t, err)
		_, err = m.Initiate(nil)
		require.ErrorIs(t, err, ErrInitiateAlreadyCalled)
		assert.True(t, m.Failed())
	})
	t.Run("process packet before initiate on initiator", func(t *testing.T) {
		m := newTestMachine(t, cs, v, true, 100)
		_, _, err := m.ProcessPacket(nil, make([]byte, 100))
		require.ErrorIs(t, err, ErrInitiateNotCalled)
		assert.True(t, m.Failed())
	})
	t.Run("calling failed machine", func(t *testing.T) {
		// Once a machine fails, every subsequent call reports ErrMachineFailed.
		m := newTestMachine(t, cs, v, false, 100)
		_, err := m.Initiate(nil) // fails: responder
		require.Error(t, err)
		_, err = m.Initiate(nil) // fails: already failed
		require.ErrorIs(t, err, ErrMachineFailed)
	})
}
// TestMachineProcessPacketErrors distinguishes recoverable ProcessPacket
// failures (short packet, AEAD failure, subtype mismatch — the machine can
// still complete later) from fatal ones (cert validation failure).
func TestMachineProcessPacketErrors(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	v := testVerifier(caPool)
	t.Run("packet too short", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		_, _, err := m.ProcessPacket(nil, []byte{1, 2, 3})
		require.ErrorIs(t, err, ErrPacketTooShort)
		assert.False(t, m.Failed(), "short packet should not kill machine")
	})
	t.Run("noise decryption failure is recoverable", func(t *testing.T) {
		initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
		initM := newTestMachine(t, initCS, v, true, 100)
		msg1, err := initM.Initiate(nil)
		require.NoError(t, err)
		respM := newTestMachine(t, cs, v, false, 200)
		resp, _, err := respM.ProcessPacket(nil, msg1)
		require.NoError(t, err)
		// Flip every byte after the header so the noise payload fails to
		// decrypt while the header still parses.
		corrupted := make([]byte, len(resp))
		copy(corrupted, resp)
		for i := header.Len; i < len(corrupted); i++ {
			corrupted[i] ^= 0xff
		}
		_, _, err = initM.ProcessPacket(nil, corrupted)
		require.Error(t, err)
		assert.False(t, initM.Failed(), "noise failure should be recoverable")
		// And the machine should still complete a real handshake afterward.
		_, result, err := initM.ProcessPacket(nil, resp)
		require.NoError(t, err)
		require.NotNil(t, result, "initiator should complete on the legitimate response")
	})
	t.Run("invalid cert is fatal", func(t *testing.T) {
		// Initiator is signed by a CA the responder's verifier doesn't trust.
		otherCA, _, otherCAKey, _ := ct.NewTestCaCert(
			cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
		)
		otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
		initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100)
		msg1, err := initM.Initiate(nil)
		require.NoError(t, err)
		respM := newTestMachine(t, cs, v, false, 200)
		_, _, err = respM.ProcessPacket(nil, msg1)
		require.Error(t, err)
		assert.True(t, respM.Failed(), "cert validation failure should kill machine")
	})
	t.Run("subtype mismatch is recoverable", func(t *testing.T) {
		initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
		initM := newTestMachine(t, initCS, v, true, 100)
		msg1, err := initM.Initiate(nil)
		require.NoError(t, err)
		// Mutate the subtype byte (offset 1 in the header) to a value the
		// responder Machine wasn't built for.
		bad := make([]byte, len(msg1))
		copy(bad, msg1)
		bad[1] = 0xff
		respM := newTestMachine(t, cs, v, false, 200)
		_, _, err = respM.ProcessPacket(nil, bad)
		require.ErrorIs(t, err, ErrSubtypeMismatch)
		assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine")
		// And the machine should still complete a real handshake afterward.
		resp, result, err := respM.ProcessPacket(nil, msg1)
		require.NoError(t, err)
		require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet")
		assert.NotEmpty(t, resp, "responder should produce a stage-2 reply")
	})
}
// TestMachineProcessPayload exercises processPayload's internal validation
// directly. Most of these failure modes can't be reached black-box once the
// subtype check at the top of ProcessPacket gates external callers, so we
// drive them by hand here for coverage.
func TestMachineProcessPayload(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	v := testVerifier(caPool)
	t.Run("empty message with expects fails", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true})
		require.ErrorIs(t, err, ErrMissingContent)
		assert.True(t, m.Failed())
	})
	t.Run("empty message with no expects passes", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		err := m.processPayload(nil, msgFlags{})
		require.NoError(t, err)
		assert.False(t, m.Failed())
	})
	t.Run("malformed protobuf is fatal", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true})
		require.Error(t, err)
		assert.True(t, m.Failed())
	})
	t.Run("unexpected payload data is fatal", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		// A payload with index data when none was expected.
		bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})
		err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
		require.ErrorIs(t, err, ErrUnexpectedContent)
		assert.True(t, m.Failed())
	})
	t.Run("unexpected cert data is fatal", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		// A payload with cert when none was expected.
		bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
		err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
		require.ErrorIs(t, err, ErrUnexpectedContent)
		assert.True(t, m.Failed())
	})
	t.Run("missing payload data when expected is fatal", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		// Cert present, but no index/time fields.
		bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
		err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true})
		require.ErrorIs(t, err, ErrUnexpectedContent)
		assert.True(t, m.Failed())
	})
}
// TestMachineRequireComplete checks the fail-on-incomplete-handshake path
// directly. Like processPayload above this isn't reachable from a normal IX
// flow, so we drive it by hand.
func TestMachineRequireComplete(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	v := testVerifier(caPool)
	// requireComplete needs both the payload and the remote cert to have been
	// observed; any other combination fails and marks the machine dead.
	t.Run("missing both fails", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		err := m.requireComplete()
		require.ErrorIs(t, err, ErrIncompleteHandshake)
		assert.True(t, m.Failed())
	})
	t.Run("payload only fails", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		m.payloadSet = true
		err := m.requireComplete()
		require.ErrorIs(t, err, ErrIncompleteHandshake)
		assert.True(t, m.Failed())
	})
	t.Run("cert only fails", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		m.remoteCertSet = true
		err := m.requireComplete()
		require.ErrorIs(t, err, ErrIncompleteHandshake)
		assert.True(t, m.Failed())
	})
	t.Run("both set passes", func(t *testing.T) {
		m := newTestMachine(t, cs, v, false, 100)
		m.payloadSet = true
		m.remoteCertSet = true
		err := m.requireComplete()
		require.NoError(t, err)
		assert.False(t, m.Failed())
	})
}
// TestMachineAESCipher repeats the happy-path handshake with both sides
// configured for AES-GCM instead of the default cipher, then round-trips
// data in both directions with the derived keys.
func TestMachineAESCipher(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	initCS := newTestCertStateWithCipher(
		t, ca, caKey, "init",
		[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
		noiseutil.CipherAESGCM,
	)
	respCS := newTestCertStateWithCipher(
		t, ca, caKey, "resp",
		[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
		noiseutil.CipherAESGCM,
	)
	initR, respR := doFullHandshake(t, initCS, respCS, caPool)
	ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works"))
	require.NoError(t, err)
	pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
	require.NoError(t, err)
	assert.Equal(t, []byte("works"), pt1)
	ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back"))
	require.NoError(t, err)
	pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
	require.NoError(t, err)
	assert.Equal(t, []byte("back"), pt2)
}
// TestResultFields spot-checks the bookkeeping fields on both handshake
// Results: the initiator flag, the recorded handshake time, and the
// presence of the peer certificate.
func TestResultFields(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	pool := ct.NewTestCAPool(ca)
	csInit := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	csResp := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
	rInit, rResp := doFullHandshake(t, csInit, csResp, pool)

	// Exactly one side is the initiator.
	assert.True(t, rInit.Initiator)
	assert.False(t, rResp.Initiator)
	// Both sides recorded a timestamp and captured the peer's cert.
	assert.NotZero(t, rInit.HandshakeTime)
	assert.NotNil(t, rInit.RemoteCert)
	assert.NotZero(t, rResp.HandshakeTime)
	assert.NotNil(t, rResp.RemoteCert)
}
// TestMachineBufferReuse verifies the zero-copy contract of Initiate and
// ProcessPacket: when the caller passes a pre-sized buffer, the output must
// share that buffer's backing array, and a nil buffer must still work.
func TestMachineBufferReuse(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
	v := testVerifier(caPool)
	initM := newTestMachine(t, initCS, v, true, 1000)
	respM := newTestMachine(t, respCS, v, false, 2000)
	msg1, err := initM.Initiate(nil)
	require.NoError(t, err)
	t.Run("response writes into provided buffer", func(t *testing.T) {
		buf := make([]byte, 0, 4096)
		resp, result, err := respM.ProcessPacket(buf, msg1)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.NotEmpty(t, resp, "response should have content")
		// Compare the address of the first element to prove both slices
		// share one backing array (buf[:1] re-extends the zero-len slice).
		assert.Equal(t, &buf[:1][0], &resp[:1][0],
			"response should reuse the provided buffer's backing array")
	})
	t.Run("initiate writes into provided buffer", func(t *testing.T) {
		initM2 := newTestMachine(t, initCS, v, true, 3000)
		buf := make([]byte, 0, 4096)
		msg, err := initM2.Initiate(buf)
		require.NoError(t, err)
		assert.NotEmpty(t, msg, "initiate should have content")
		assert.Equal(t, &buf[:1][0], &msg[:1][0],
			"initiate should reuse the provided buffer's backing array")
	})
	t.Run("nil out still works", func(t *testing.T) {
		initM2 := newTestMachine(t, initCS, v, true, 4000)
		respM2 := newTestMachine(t, respCS, v, false, 5000)
		msg1, err := initM2.Initiate(nil)
		require.NoError(t, err)
		resp, _, err := respM2.ProcessPacket(nil, msg1)
		require.NoError(t, err)
		out, result, err := initM2.ProcessPacket(nil, resp)
		require.NoError(t, err)
		assert.NotNil(t, result)
		assert.Nil(t, out, "initiator should have no response for IX msg2")
	})
}
// TestMachineMsgIndexTracking runs a manual two-message IX exchange and
// checks that each side reports completion (a non-nil Result) at the
// expected step.
func TestMachineMsgIndexTracking(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	verifier := testVerifier(ct.NewTestCAPool(ca))
	csA := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	csB := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
	mInit := newTestMachine(t, csA, verifier, true, 100)
	mResp := newTestMachine(t, csB, verifier, false, 200)

	stage1, err := mInit.Initiate(nil)
	require.NoError(t, err)

	// The responder completes as soon as it processes stage 1 (IX).
	stage2, respResult, err := mResp.ProcessPacket(nil, stage1)
	require.NoError(t, err)
	assert.NotNil(t, respResult)

	// The initiator completes when it consumes the stage-2 reply.
	_, initResult, err := mInit.ProcessPacket(nil, stage2)
	require.NoError(t, err)
	assert.NotNil(t, initResult)
}
// TestMachineThreeMessagePattern verifies the Machine handles multi-message
// patterns correctly using HandshakeXX (3 messages). XX flow:
//
//	msg1 (I->R): [E]             - payload only, no cert
//	msg2 (R->I): [E, ee, S, es]  - payload + cert
//	msg3 (I->R): [S, se]         - cert only (no payload, not first two)
func TestMachineThreeMessagePattern(t *testing.T) {
	registerTestXXInfo(t)
	// Use HandshakeXX (3 messages) to verify the Machine handles multi-message
	// patterns correctly. XX flow:
	// msg1 (I->R): [E] - payload only, no cert
	// msg2 (R->I): [E, ee, S, es] - payload + cert
	// msg3 (I->R): [S, se] - cert only (no payload, not first two)
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	v := testVerifier(caPool)
	initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
	respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
	initM, err := NewMachine(
		cert.Version2,
		initCS.getCredential, v,
		func() (uint32, error) { return 1000, nil },
		true, header.HandshakeXXPSK0,
	)
	require.NoError(t, err)
	respM, err := NewMachine(
		cert.Version2,
		respCS.getCredential, v,
		func() (uint32, error) { return 2000, nil },
		false, header.HandshakeXXPSK0,
	)
	require.NoError(t, err)
	// msg1: initiator -> responder (E only, no cert)
	msg1, err := initM.Initiate(nil)
	require.NoError(t, err)
	assert.NotEmpty(t, msg1)
	// Responder processes msg1, should not complete yet, should produce msg2
	msg2, result, err := respM.ProcessPacket(nil, msg1)
	require.NoError(t, err)
	assert.Nil(t, result, "XX should not complete on msg1")
	assert.NotEmpty(t, msg2, "responder should produce msg2")
	// Initiator processes msg2: gets responder's cert, produces msg3, and
	// completes (WriteMessage for msg3 derives keys)
	msg3, initResult, err := initM.ProcessPacket(nil, msg2)
	require.NoError(t, err)
	require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3")
	assert.NotEmpty(t, msg3, "initiator should produce msg3")
	assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name())
	// Responder processes msg3: gets initiator's cert and completes
	_, respResult, err := respM.ProcessPacket(nil, msg3)
	require.NoError(t, err)
	require.NotNil(t, respResult, "XX responder should complete on msg3")
	assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name())
	assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages")
	assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages")
	// Verify keys work
	ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages"))
	require.NoError(t, err)
	pt1, err := respResult.DKey.Decrypt(nil, nil, ct1)
	require.NoError(t, err)
	assert.Equal(t, []byte("three messages"), pt1)
}
// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with
// IX since the cert is always in the payload. A 3-message pattern test (HybridIX)
// should exercise the case where cert arrives in msg3 and verify that completing
// without it fails.

// TestMachineExpiredCert presents an initiator credential whose validity
// window ended in the past and asserts the responder's verifier rejects it,
// fatally failing the responder machine.
func TestMachineExpiredCert(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519,
		time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour),
		nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	// Node cert valid from -2h to -1h: already expired.
	expCert, _, expKeyPEM, _ := ct.NewTestCert(
		cert.Version2, cert.Curve_CURVE25519, ca, caKey,
		"expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour),
		[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil,
	)
	expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM)
	require.NoError(t, err)
	expHsBytes, err := expCert.MarshalForHandshakes()
	require.NoError(t, err)
	ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
	// Build the cert state by hand since the helper would mint a valid cert.
	expiredCS := &testCertState{
		version: cert.Version2,
		creds: map[cert.Version]*Credential{
			cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs),
		},
	}
	respCS := newTestCertState(
		t, ca, caKey, "responder",
		[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
	)
	_, respM, _, _, err := initiateHandshake(
		t, expiredCS, testVerifier(caPool),
		respCS, testVerifier(caPool),
	)
	require.ErrorContains(t, err, "verify cert")
	assert.True(t, respM.Failed())
}
// TestMachineNoCertNetworks hands the initiator the CA certificate itself as
// its node credential — a cert carrying no networks — and asserts the
// responder rejects the handshake and fails fatally.
func TestMachineNoCertNetworks(t *testing.T) {
	ca, _, caKey, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca)
	caHsBytes, err := ca.MarshalForHandshakes()
	require.NoError(t, err)
	ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
	// Credential built from the CA cert/key pair rather than a node cert.
	noNetCS := &testCertState{
		version: cert.Version2,
		creds: map[cert.Version]*Credential{
			cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs),
		},
	}
	respCS := newTestCertState(
		t, ca, caKey, "responder",
		[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
	)
	_, respM, _, _, err := initiateHandshake(
		t, noNetCS, testVerifier(caPool),
		respCS, testVerifier(caPool),
	)
	require.Error(t, err)
	assert.True(t, respM.Failed())
}
// TestMachineDifferentCAs gives each side a cert signed by a CA the other
// does not trust; the responder's cert verification must fail and fatally
// fail its machine.
func TestMachineDifferentCAs(t *testing.T) {
	ca1, _, caKey1, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	ca2, _, caKey2, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	initCS := newTestCertState(
		t, ca1, caKey1, "init",
		[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
	)
	respCS := newTestCertState(
		t, ca2, caKey2, "resp",
		[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
	)
	// Each verifier trusts only its own CA, so the responder rejects the
	// initiator's ca1-signed cert.
	_, respM, _, _, err := initiateHandshake(
		t, initCS, testVerifier(ct.NewTestCAPool(ca1)),
		respCS, testVerifier(ct.NewTestCAPool(ca2)),
	)
	require.ErrorContains(t, err, "verify cert")
	assert.True(t, respM.Failed())
}
// TestMachineVersionNegotiation covers the cert-version negotiation the
// Machine performs when the peer sends a different cert version: a responder
// holding both V1 and V2 credentials should switch to the initiator's
// version, while a responder with only V1 keeps it.
func TestMachineVersionNegotiation(t *testing.T) {
	ca1, _, caKey1, _ := ct.NewTestCaCert(
		cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	ca2, _, caKey2, _ := ct.NewTestCaCert(
		cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
	)
	caPool := ct.NewTestCAPool(ca1, ca2)
	// makeMultiVersionResp builds a responder cert state that defaults to V1
	// but also holds a V2 credential for the same identity/key.
	makeMultiVersionResp := func(t *testing.T) *testCertState {
		t.Helper()
		respCertV1, _, respKeyPEM, _ := ct.NewTestCert(
			cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
			ca1.NotBefore(), ca1.NotAfter(),
			[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
		)
		respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
		respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2)
		respHsV1, _ := respCertV1.MarshalForHandshakes()
		respHsV2, _ := respCertV2.MarshalForHandshakes()
		ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
		return &testCertState{
			version: cert.Version1,
			creds: map[cert.Version]*Credential{
				cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs),
				cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs),
			},
		}
	}
	t.Run("responder matches initiator version", func(t *testing.T) {
		initCS := newTestCertState(
			t, ca2, caKey2, "init",
			[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
		)
		respCS := makeMultiVersionResp(t)
		v := testVerifier(caPool)
		initM, _, respResult, resp, err := initiateHandshake(
			t, initCS, v,
			respCS, v,
		)
		require.NoError(t, err)
		require.NotNil(t, respResult)
		assert.Equal(t, cert.Version2, respResult.MyCert.Version(),
			"responder should negotiate to initiator's version")
		_, initResult, err := initM.ProcessPacket(nil, resp)
		require.NoError(t, err)
		require.NotNil(t, initResult)
		assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(),
			"initiator should see V2 cert from responder")
	})
	t.Run("responder keeps version when no match available", func(t *testing.T) {
		initCS := newTestCertState(
			t, ca2, caKey2, "init",
			[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
		)
		// Responder holds only a V1 credential, so it can't switch to V2.
		respCert, _, respKeyPEM, _ := ct.NewTestCert(
			cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
			ca1.NotBefore(), ca1.NotAfter(),
			[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
		)
		respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
		respHs, _ := respCert.MarshalForHandshakes()
		ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
		respCS := &testCertState{
			version: cert.Version1,
			creds: map[cert.Version]*Credential{
				cert.Version1: NewCredential(respCert, respHs, respKey, ncs),
			},
		}
		v := testVerifier(caPool)
		_, _, respResult, _, err := initiateHandshake(
			t, initCS, v,
			respCS, v,
		)
		require.NoError(t, err)
		require.NotNil(t, respResult)
		assert.Equal(t, cert.Version1, respResult.MyCert.Version(),
			"responder should keep V1 when V2 not available")
	})
}

54
handshake/patterns.go Normal file
View File

@@ -0,0 +1,54 @@
package handshake
import (
"fmt"
"github.com/flynn/noise"
"github.com/slackhq/nebula/header"
)
// msgFlags tracks what application data a handshake message carries.
type msgFlags struct {
	expectsPayload bool // message carries indexes and time
	expectsCert    bool // message carries the certificate
}

// subtypeInfo bundles the noise pattern with the per-message flags for a
// given handshake subtype.
type subtypeInfo struct {
	pattern noise.HandshakePattern // the noise pattern driving the exchange
	msgs    []msgFlags             // msgs[i] describes message i+1 of the pattern
}

// subtypeInfos defines the noise pattern and message content layout for each
// handshake subtype. Entries are looked up via subtypeInfoFor.
var subtypeInfos = map[header.MessageSubType]subtypeInfo{
	// IX: 2 messages, both carry payload and cert
	header.HandshakeIXPSK0: {
		pattern: noise.HandshakeIX,
		msgs: []msgFlags{
			{expectsPayload: true, expectsCert: true},
			{expectsPayload: true, expectsCert: true},
		},
	},
	// XX: 3 messages
	// msg1 (I->R): payload only
	// msg2 (R->I): payload + cert
	// msg3 (I->R): cert only
	//header.HandshakeXXPSK0: {
	//	pattern: noise.HandshakeXX,
	//	msgs: []msgFlags{
	//		{expectsPayload: true, expectsCert: false},
	//		{expectsPayload: true, expectsCert: true},
	//		{expectsPayload: false, expectsCert: true},
	//	},
	//},
}
// subtypeInfoFor returns the pattern/message-layout metadata registered for
// subtype, or a wrapped ErrUnknownSubtype when no entry exists.
func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) {
	info, ok := subtypeInfos[subtype]
	if !ok {
		return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype)
	}
	return info, nil
}

View File

@@ -0,0 +1,63 @@
package handshake
import (
"testing"
"github.com/flynn/noise"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSubtypeInfo checks the registered subtype tables: the IX entry's
// per-message flags, the temporarily-registered XX entry's flags, and the
// error path for an unknown subtype.
func TestSubtypeInfo(t *testing.T) {
	t.Run("IX", func(t *testing.T) {
		info, err := subtypeInfoFor(header.HandshakeIXPSK0)
		require.NoError(t, err)
		assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name)
		require.Len(t, info.msgs, 2)
		// msg1: payload + cert
		assert.True(t, info.msgs[0].expectsPayload)
		assert.True(t, info.msgs[0].expectsCert)
		// msg2: payload + cert
		assert.True(t, info.msgs[1].expectsPayload)
		assert.True(t, info.msgs[1].expectsCert)
	})
	t.Run("XX", func(t *testing.T) {
		registerTestXXInfo(t)
		info, err := subtypeInfoFor(header.HandshakeXXPSK0)
		require.NoError(t, err)
		assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name)
		require.Len(t, info.msgs, 3)
		// msg1: payload only
		assert.True(t, info.msgs[0].expectsPayload)
		assert.False(t, info.msgs[0].expectsCert)
		// msg2: payload + cert
		assert.True(t, info.msgs[1].expectsPayload)
		assert.True(t, info.msgs[1].expectsCert)
		// msg3: cert only
		assert.False(t, info.msgs[2].expectsPayload)
		assert.True(t, info.msgs[2].expectsCert)
	})
	t.Run("unknown subtype returns error", func(t *testing.T) {
		_, err := subtypeInfoFor(99)
		require.ErrorIs(t, err, ErrUnknownSubtype)
	})
}
// registerTestXXInfo temporarily registers XX subtype info for testing.
// The entry is removed via t.Cleanup; note this mutates the package-level
// subtypeInfos map, so callers must not run in parallel with other tests
// reading that map.
func registerTestXXInfo(t *testing.T) {
	t.Helper()
	subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{
		pattern: noise.HandshakeXX,
		msgs: []msgFlags{
			{expectsPayload: true, expectsCert: false}, // msg1 (I->R): payload only
			{expectsPayload: true, expectsCert: true},  // msg2 (R->I): payload + cert
			{expectsPayload: false, expectsCert: true}, // msg3 (I->R): cert only
		},
	}
	t.Cleanup(func() {
		delete(subtypeInfos, header.HandshakeXXPSK0)
	})
}

173
handshake/payload.go Normal file
View File

@@ -0,0 +1,173 @@
package handshake
import (
"errors"
"math"
"google.golang.org/protobuf/encoding/protowire"
)
// Sentinel parse errors returned by UnmarshalPayload and its helpers.
var (
	errInvalidHandshakeMessage = errors.New("invalid handshake message")
	errInvalidHandshakeDetails = errors.New("invalid handshake details")
)

// Payload represents the decoded fields of a handshake message.
// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
type Payload struct {
	Cert           []byte // marshalled certificate bytes; CertVersion says how to decode them
	InitiatorIndex uint32 // initiator's local index
	ResponderIndex uint32 // responder's local index (zero in the stage-1 message)
	Time           uint64 // sender timestamp, nanoseconds since the unix epoch
	CertVersion    uint32 // cert.Version of Cert
}

// Proto field numbers for NebulaHandshakeDetails
const (
	fieldCert           = 1 // bytes
	fieldInitiatorIndex = 2 // uint32
	fieldResponderIndex = 3 // uint32
	fieldTime           = 5 // uint64
	fieldCertVersion    = 8 // uint32
)
// MarshalPayload encodes a handshake payload in protobuf wire format compatible
// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
// Returns out (which may be nil), with the marshalled Payload appended to it.
func MarshalPayload(out []byte, p Payload) []byte {
	var d []byte
	// uvarint appends "field: value" for one varint field, omitting zero
	// values (proto3 semantics: zero-valued scalars stay off the wire).
	uvarint := func(field protowire.Number, v uint64) {
		if v != 0 {
			d = protowire.AppendTag(d, field, protowire.VarintType)
			d = protowire.AppendVarint(d, v)
		}
	}
	if len(p.Cert) > 0 {
		d = protowire.AppendTag(d, fieldCert, protowire.BytesType)
		d = protowire.AppendBytes(d, p.Cert)
	}
	uvarint(fieldInitiatorIndex, uint64(p.InitiatorIndex))
	uvarint(fieldResponderIndex, uint64(p.ResponderIndex))
	uvarint(fieldTime, p.Time)
	uvarint(fieldCertVersion, uint64(p.CertVersion))
	// Wrap the details as field 1 of the outer NebulaHandshake message; the
	// (possibly empty) length-delimited field is always emitted.
	out = protowire.AppendTag(out, 1, protowire.BytesType)
	out = protowire.AppendBytes(out, d)
	return out
}
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
// Only outer field 1 (the embedded NebulaHandshakeDetails) is interpreted;
// any other outer field is skipped so newer peers can add fields safely.
func UnmarshalPayload(b []byte) (Payload, error) {
	var p Payload
	for len(b) > 0 {
		num, typ, n := protowire.ConsumeTag(b)
		if n < 0 {
			return p, errInvalidHandshakeMessage
		}
		b = b[n:]
		switch {
		case num == 1 && typ == protowire.BytesType:
			details, n := protowire.ConsumeBytes(b)
			if n < 0 {
				return p, errInvalidHandshakeMessage
			}
			b = b[n:]
			if err := unmarshalPayloadDetails(&p, details); err != nil {
				return p, err
			}
		default:
			// Unknown outer field: skip its entire value.
			n := protowire.ConsumeFieldValue(num, typ, b)
			if n < 0 {
				return p, errInvalidHandshakeMessage
			}
			b = b[n:]
		}
	}
	return p, nil
}
// unmarshalPayloadDetails decodes the inner NebulaHandshakeDetails message
// into p.
//
// For known field numbers, any non-matching wire type is rejected as a hard
// error rather than silently skipped. The caller will catch missing-field
// cases downstream, but a wire-type mismatch on a tag we know is a peer
// protocol violation worth flagging here. Repeated occurrences of a singular
// field follow proto3 last-wins; unknown fields are skipped for forward
// compatibility.
func unmarshalPayloadDetails(p *Payload, b []byte) error {
	// varint32 consumes one varint that must fit in 32 bits, reporting
	// failure the same way protowire does (negative length).
	varint32 := func(buf []byte) (uint32, int) {
		v, n := protowire.ConsumeVarint(buf)
		if n < 0 || v > math.MaxUint32 {
			return 0, -1
		}
		return uint32(v), n
	}
	for len(b) > 0 {
		num, typ, n := protowire.ConsumeTag(b)
		if n < 0 {
			return errInvalidHandshakeDetails
		}
		b = b[n:]
		switch num {
		case fieldCert:
			if typ != protowire.BytesType {
				return errInvalidHandshakeDetails
			}
			v, n := protowire.ConsumeBytes(b)
			if n < 0 {
				return errInvalidHandshakeDetails
			}
			// Copy out: ConsumeBytes returns a subslice aliasing b.
			p.Cert = append([]byte(nil), v...)
			b = b[n:]
		case fieldInitiatorIndex, fieldResponderIndex, fieldCertVersion:
			// All three are uint32 varints; decode once, then route.
			if typ != protowire.VarintType {
				return errInvalidHandshakeDetails
			}
			v, n := varint32(b)
			if n < 0 {
				return errInvalidHandshakeDetails
			}
			switch num {
			case fieldInitiatorIndex:
				p.InitiatorIndex = v
			case fieldResponderIndex:
				p.ResponderIndex = v
			default:
				p.CertVersion = v
			}
			b = b[n:]
		case fieldTime:
			if typ != protowire.VarintType {
				return errInvalidHandshakeDetails
			}
			v, n := protowire.ConsumeVarint(b)
			if n < 0 {
				return errInvalidHandshakeDetails
			}
			p.Time = v
			b = b[n:]
		default:
			n := protowire.ConsumeFieldValue(num, typ, b)
			if n < 0 {
				return errInvalidHandshakeDetails
			}
			b = b[n:]
		}
	}
	return nil
}

361
handshake/payload_test.go Normal file
View File

@@ -0,0 +1,361 @@
package handshake
import (
"bytes"
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protowire"
)
func TestPayloadRoundTrip(t *testing.T) {
t.Run("all fields set", func(t *testing.T) {
data := MarshalPayload(nil, Payload{
Cert: []byte("test-cert-bytes"),
CertVersion: 2,
InitiatorIndex: 12345,
ResponderIndex: 67890,
Time: 1234567890,
})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
assert.Equal(t, uint32(12345), got.InitiatorIndex)
assert.Equal(t, uint32(67890), got.ResponderIndex)
assert.Equal(t, uint64(1234567890), got.Time)
assert.Equal(t, uint32(2), got.CertVersion)
})
t.Run("minimal fields", func(t *testing.T) {
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(1), got.InitiatorIndex)
assert.Equal(t, uint32(0), got.ResponderIndex)
assert.Equal(t, uint64(0), got.Time)
assert.Nil(t, got.Cert)
})
t.Run("empty payload", func(t *testing.T) {
data := MarshalPayload(nil, Payload{})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(0), got.InitiatorIndex)
})
t.Run("large cert bytes", func(t *testing.T) {
bigCert := make([]byte, 4096)
for i := range bigCert {
bigCert[i] = byte(i % 256)
}
data := MarshalPayload(nil, Payload{
Cert: bigCert,
CertVersion: 2,
InitiatorIndex: 999,
})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, bigCert, got.Cert)
assert.Equal(t, uint32(999), got.InitiatorIndex)
})
t.Run("append to existing buffer", func(t *testing.T) {
prefix := []byte("prefix")
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
assert.Equal(t, []byte("prefix"), data[:6])
got, err := UnmarshalPayload(data[6:])
require.NoError(t, err)
assert.Equal(t, uint32(42), got.InitiatorIndex)
})
}
// TestPayloadUnknownFields confirms forward compatibility: unknown and
// reserved field numbers are skipped without disturbing known fields.
func TestPayloadUnknownFields(t *testing.T) {
	t.Run("unknown field in outer message is skipped", func(t *testing.T) {
		// Start from a valid payload, then tack on field 99 as a varint.
		buf := MarshalPayload(nil, Payload{InitiatorIndex: 42})
		buf = protowire.AppendTag(buf, 99, protowire.VarintType)
		buf = protowire.AppendVarint(buf, 12345)
		out, err := UnmarshalPayload(buf)
		require.NoError(t, err)
		assert.Equal(t, uint32(42), out.InitiatorIndex)
	})

	t.Run("unknown field in details is skipped", func(t *testing.T) {
		// Known field, then unknown field 50, then another known field —
		// the parser must resynchronize after the skip.
		inner := protowire.AppendTag(nil, fieldInitiatorIndex, protowire.VarintType)
		inner = protowire.AppendVarint(inner, 77)
		inner = protowire.AppendTag(inner, 50, protowire.VarintType)
		inner = protowire.AppendVarint(inner, 9999)
		inner = protowire.AppendTag(inner, fieldResponderIndex, protowire.VarintType)
		inner = protowire.AppendVarint(inner, 88)
		// Wrap in the outer message envelope.
		outer := protowire.AppendTag(nil, 1, protowire.BytesType)
		outer = protowire.AppendBytes(outer, inner)
		out, err := UnmarshalPayload(outer)
		require.NoError(t, err)
		assert.Equal(t, uint32(77), out.InitiatorIndex)
		assert.Equal(t, uint32(88), out.ResponderIndex)
	})

	t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
		// Fields 6 and 7 are reserved in the proto definition.
		inner := protowire.AppendTag(nil, fieldInitiatorIndex, protowire.VarintType)
		inner = protowire.AppendVarint(inner, 100)
		inner = protowire.AppendTag(inner, 6, protowire.VarintType)
		inner = protowire.AppendVarint(inner, 1)
		inner = protowire.AppendTag(inner, 7, protowire.VarintType)
		inner = protowire.AppendVarint(inner, 2)
		outer := protowire.AppendTag(nil, 1, protowire.BytesType)
		outer = protowire.AppendBytes(outer, inner)
		out, err := UnmarshalPayload(outer)
		require.NoError(t, err)
		assert.Equal(t, uint32(100), out.InitiatorIndex)
	})
}
func TestPayloadBytesConsumed(t *testing.T) {
t.Run("all bytes consumed on valid input", func(t *testing.T) {
original := Payload{
Cert: []byte("cert"),
CertVersion: 2,
InitiatorIndex: 100,
ResponderIndex: 200,
Time: 999,
}
data := MarshalPayload(nil, original)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
// Re-marshal and compare — proves we consumed and reproduced all fields
remarshaled := MarshalPayload(nil, got)
assert.Equal(t, data, remarshaled)
})
}
// wrapDetails places raw detail bytes inside field 1 (bytes type) of the
// outer NebulaHandshake envelope so tests can drive unmarshalPayloadDetails
// through UnmarshalPayload.
func wrapDetails(details []byte) []byte {
	outer := protowire.AppendTag(nil, 1, protowire.BytesType)
	return protowire.AppendBytes(outer, details)
}
// TestPayloadUnmarshalErrors exercises malformed and hostile wire input:
// truncated tags and values, wrong wire types on known fields, varint
// overflow, and proto3 last-wins semantics for repeated singular fields.
func TestPayloadUnmarshalErrors(t *testing.T) {
	t.Run("nil input", func(t *testing.T) {
		// nil decodes as the empty payload, not an error.
		out, err := UnmarshalPayload(nil)
		require.NoError(t, err)
		assert.Equal(t, uint32(0), out.InitiatorIndex)
	})

	t.Run("truncated outer tag", func(t *testing.T) {
		_, err := UnmarshalPayload([]byte{0x80})
		assert.Error(t, err)
	})

	t.Run("truncated outer details field", func(t *testing.T) {
		// Length byte claims 100 bytes but only 5 follow.
		_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
		assert.Error(t, err)
	})

	t.Run("truncated outer unknown field", func(t *testing.T) {
		// Valid tag for unknown field 99 varint, but no value follows.
		_, err := UnmarshalPayload(protowire.AppendTag(nil, 99, protowire.VarintType))
		assert.Error(t, err)
	})

	t.Run("truncated details tag", func(t *testing.T) {
		_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
		assert.Error(t, err)
	})

	t.Run("truncated cert bytes", func(t *testing.T) {
		// Field 1 (cert), bytes type, length 10 but only 2 bytes present.
		details := protowire.AppendTag(nil, fieldCert, protowire.BytesType)
		details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
		_, err := UnmarshalPayload(wrapDetails(details))
		assert.Error(t, err)
	})

	// Every varint-typed field (and an unknown one) must reject an
	// incomplete varint value.
	truncatedVarints := []struct {
		name  string
		field protowire.Number
	}{
		{"truncated initiator index varint", fieldInitiatorIndex},
		{"truncated responder index varint", fieldResponderIndex},
		{"truncated time varint", fieldTime},
		{"truncated cert version varint", fieldCertVersion},
		{"truncated unknown field in details", 50},
	}
	for _, tc := range truncatedVarints {
		t.Run(tc.name, func(t *testing.T) {
			details := protowire.AppendTag(nil, tc.field, protowire.VarintType)
			details = append(details, 0x80) // incomplete varint
			_, err := UnmarshalPayload(wrapDetails(details))
			assert.Error(t, err)
		})
	}

	// Known fields carried with the wrong wire type are protocol violations.
	wrongWireTypes := []struct {
		name  string
		build func() []byte
	}{
		{"cert with wrong wire type rejected", func() []byte {
			// fieldCert as Varint instead of Bytes.
			d := protowire.AppendTag(nil, fieldCert, protowire.VarintType)
			return protowire.AppendVarint(d, 42)
		}},
		{"initiator index with wrong wire type rejected", func() []byte {
			// fieldInitiatorIndex as Bytes instead of Varint.
			d := protowire.AppendTag(nil, fieldInitiatorIndex, protowire.BytesType)
			return protowire.AppendBytes(d, []byte{1, 2, 3})
		}},
		{"time with wrong wire type rejected", func() []byte {
			d := protowire.AppendTag(nil, fieldTime, protowire.BytesType)
			return protowire.AppendBytes(d, []byte{1, 2, 3})
		}},
		{"cert version with wrong wire type rejected", func() []byte {
			d := protowire.AppendTag(nil, fieldCertVersion, protowire.BytesType)
			return protowire.AppendBytes(d, []byte{1, 2, 3})
		}},
	}
	for _, tc := range wrongWireTypes {
		t.Run(tc.name, func(t *testing.T) {
			_, err := UnmarshalPayload(wrapDetails(tc.build()))
			assert.Error(t, err)
		})
	}

	t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
		// Per proto3, multiple instances of a singular field are accepted and
		// the last value wins. We keep this behavior so that peers using
		// alternative encoders aren't rejected.
		details := protowire.AppendTag(nil, fieldInitiatorIndex, protowire.VarintType)
		details = protowire.AppendVarint(details, 1)
		details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
		details = protowire.AppendVarint(details, 42)
		out, err := UnmarshalPayload(wrapDetails(details))
		require.NoError(t, err)
		assert.Equal(t, uint32(42), out.InitiatorIndex)
	})

	// Values wider than uint32 must be rejected on uint32 fields.
	overflows := []struct {
		name  string
		field protowire.Number
	}{
		{"initiator index varint overflow rejected", fieldInitiatorIndex},
		{"cert version varint overflow rejected", fieldCertVersion},
	}
	for _, tc := range overflows {
		t.Run(tc.name, func(t *testing.T) {
			details := protowire.AppendTag(nil, tc.field, protowire.VarintType)
			details = protowire.AppendVarint(details, math.MaxUint32+1)
			_, err := UnmarshalPayload(wrapDetails(details))
			assert.Error(t, err)
		})
	}
}
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
// never panics, and that any input which parses cleanly survives a
// re-marshal + re-parse as a fix-point. Inputs come from an authenticated
// peer (post-noise-decrypt), so the threat model is "valid peer behaving
// arbitrarily," not "unauthenticated injection."
func FuzzPayload(f *testing.F) {
	// Seed corpus: a few known-good encodings plus trivially bad bytes.
	for _, seed := range [][]byte{
		MarshalPayload(nil, Payload{}),
		MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}),
		MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}),
		MarshalPayload(nil, Payload{
			Cert:           []byte("seed-cert"),
			InitiatorIndex: 1,
			ResponderIndex: 2,
			Time:           3,
			CertVersion:    2,
		}),
		{},
		{0xff},
	} {
		f.Add(seed)
	}
	f.Fuzz(func(t *testing.T, data []byte) {
		first, err := UnmarshalPayload(data)
		if err != nil {
			return
		}
		// Anything that parses must survive marshal → parse unchanged. This
		// catches dispatch bugs (emitting a field on marshal that parse
		// rejects) and any non-idempotent parsing behavior.
		encoded := MarshalPayload(nil, first)
		second, err := UnmarshalPayload(encoded)
		if err != nil {
			t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, encoded)
		}
		if !payloadsEqual(first, second) {
			t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", first, second)
		}
	})
}
// payloadsEqual reports whether two Payloads carry identical field values,
// comparing Cert by content rather than by slice identity.
func payloadsEqual(a, b Payload) bool {
	if a.InitiatorIndex != b.InitiatorIndex || a.ResponderIndex != b.ResponderIndex {
		return false
	}
	if a.Time != b.Time || a.CertVersion != b.CertVersion {
		return false
	}
	return bytes.Equal(a.Cert, b.Cert)
}

View File

@@ -1,746 +0,0 @@
package nebula
import (
"bytes"
"net/netip"
"time"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
)
// NOISE IX Handshakes

// ixHandshakeStage0 constructs the stage-0 IX handshake packet for the host
// described by hh, but does not actually send it — sending is done by the
// handshake manager. On success the encrypted packet is cached in
// hh.hostinfo.HandshakePacket[0] and hh is marked ready; returns false (with
// the reason logged) if any step of building the packet fails.
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
	// Reserve a local index so the responder's reply can be matched back to
	// this pending handshake.
	err := f.handshakeManager.allocateIndex(hh)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
		return false
	}

	// Pick which certificate version to present: an explicit per-host
	// override wins; otherwise upgrade to V2 when any target address is IPv6.
	cs := f.pki.getCertState()
	v := cs.initiatingVersion
	if hh.initiatingVersionOverride != cert.VersionPre1 {
		v = hh.initiatingVersionOverride
	} else if v < cert.Version2 {
		// If we're connecting to a v6 address we should encourage use of a V2 cert
		for _, a := range hh.hostinfo.vpnAddrs {
			if a.Is6() {
				v = cert.Version2
				break
			}
		}
	}

	crt := cs.getCertificate(v)
	if crt == nil {
		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
			WithField("certVersion", v).
			Error("Unable to handshake with host because no certificate is available")
		return false
	}

	// Pre-serialized certificate bytes for the chosen version, embedded in
	// the handshake details below.
	crtHs := cs.getHandshakeBytes(v)
	if crtHs == nil {
		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
			WithField("certVersion", v).
			Error("Unable to handshake with host because no certificate handshake bytes is available")
		return false
	}

	// true = we are the initiator of this Noise IX handshake.
	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
			WithField("certVersion", v).
			Error("Failed to create connection state")
		return false
	}
	hh.hostinfo.ConnectionState = ci

	hs := &NebulaHandshake{
		Details: &NebulaHandshakeDetails{
			InitiatorIndex: hh.hostinfo.localIndexId,
			Time:           uint64(time.Now().UnixNano()),
			Cert:           crtHs,
			CertVersion:    uint32(v),
		},
	}

	// Advertise our multiport capabilities so the responder can negotiate
	// per-direction multiport support.
	if f.multiPort.Tx || f.multiPort.Rx {
		hs.Details.InitiatorMultiPort = &MultiPortDetails{
			RxSupported: f.multiPort.Rx,
			TxSupported: f.multiPort.Tx,
			BasePort:    uint32(f.multiPort.TxBasePort),
			TotalPorts:  uint32(f.multiPort.TxPorts),
		}
	}

	hsBytes, err := hs.Marshal()
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
			WithField("certVersion", v).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
		return false
	}

	// The nebula header (message counter 1) is the associated data for the
	// first noise message; hsBytes rides as the encrypted payload.
	h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
		return false
	}

	// We are sending handshake packet 1, so we don't expect to receive
	// handshake packet 1 from the responder
	ci.window.Update(f.l, 1)

	hh.hostinfo.HandshakePacket[0] = msg
	hh.ready = true
	return true
}
// ixHandshakeStage1 handles an incoming stage-1 IX handshake packet as the
// responder: it decrypts the initiator's message, validates and verifies the
// embedded certificate, negotiates multiport support, registers the new
// HostInfo, and sends (or re-sends) the stage-2 reply either directly or via
// a relay. All failures are logged and cause an early return; there is no
// error return because the handshake manager retries via the initiator.
//
// Fix over the previous revision: in the ErrAlreadySeen cached path the
// stage-2 packet was written to f.outside unconditionally AND then sent a
// second time by the multiport/outside branch (a merge artifact), so the
// cached response went out twice and the first send's error was discarded.
// The packet is now sent exactly once.
func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
	cs := f.pki.getCertState()
	crt := cs.GetDefaultCertificate()
	if crt == nil {
		f.l.WithField("from", via).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
			WithField("certVersion", cs.initiatingVersion).
			Error("Unable to handshake with host because no certificate is available")
		return
	}

	// false = we are the responder side of this Noise IX handshake.
	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
	if err != nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Error("Failed to create connection state")
		return
	}

	// Mark packet 1 as seen so it doesn't show up as missed
	ci.window.Update(f.l, 1)

	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
	if err != nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Error("Failed to call noise.ReadMessage")
		return
	}

	hs := &NebulaHandshake{}
	err = hs.Unmarshal(msg)
	if err != nil || hs.Details == nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Error("Failed unmarshal handshake message")
		return
	}

	// Rebuild the peer's certificate from the handshake bytes plus the noise
	// static key, then verify it against our CA pool.
	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
	if err != nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Info("Handshake did not contain a certificate")
		return
	}

	remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
	if err != nil {
		fp, fperr := rc.Fingerprint()
		if fperr != nil {
			fp = "<error generating certificate fingerprint>"
		}

		e := f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			WithField("certVpnNetworks", rc.Networks()).
			WithField("certFingerprint", fp)

		if f.l.Level >= logrus.DebugLevel {
			e = e.WithField("cert", rc)
		}

		e.Info("Invalid certificate from host")
		return
	}

	// The cert must be bound to the same static key noise authenticated.
	if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
		f.l.WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
		return
	}

	if remoteCert.Certificate.Version() != ci.myCert.Version() {
		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
		myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
		if myCertOtherVersion == nil {
			if f.l.Level >= logrus.DebugLevel {
				f.l.WithError(err).WithFields(m{
					"from":      via,
					"handshake": m{"stage": 1, "style": "ix_psk0"},
					"cert":      remoteCert,
				}).Debug("Might be unable to handshake with host due to missing certificate version")
			}
		} else {
			// Record the certificate we are actually using
			ci.myCert = myCertOtherVersion
		}
	}

	if len(remoteCert.Certificate.Networks()) == 0 {
		f.l.WithError(err).WithField("from", via).
			WithField("cert", remoteCert).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Info("No networks in certificate")
		return
	}

	certName := remoteCert.Certificate.Name()
	certVersion := remoteCert.Certificate.Version()
	fingerprint := remoteCert.Fingerprint
	issuer := remoteCert.Certificate.Issuer()
	vpnNetworks := remoteCert.Certificate.Networks()

	// Collect the peer's vpn addresses, refusing self-handshakes and noting
	// whether we share any vpn network with the peer.
	anyVpnAddrsInCommon := false
	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
	for i, network := range vpnNetworks {
		if f.myVpnAddrsTable.Contains(network.Addr()) {
			f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
			return
		}
		vpnAddrs[i] = network.Addr()
		if f.myVpnNetworksTable.Contains(network.Addr()) {
			anyVpnAddrsInCommon = true
		}
	}

	if !via.IsRelayed {
		// We only want to apply the remote allow list for direct tunnels here
		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				Debug("lighthouse.remote_allow_list denied incoming handshake")
			return
		}
	}

	myIndex, err := generateIndex(f.l)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
		return
	}

	// Negotiate multiport per direction: we transmit multiport only if the
	// peer can receive it, and vice versa. Always advertise our own support
	// back to the initiator.
	var multiportTx, multiportRx bool
	if f.multiPort.Rx || f.multiPort.Tx {
		if hs.Details.InitiatorMultiPort != nil {
			multiportTx = hs.Details.InitiatorMultiPort.RxSupported && f.multiPort.Tx
			multiportRx = hs.Details.InitiatorMultiPort.TxSupported && f.multiPort.Rx
		}
		hs.Details.ResponderMultiPort = &MultiPortDetails{
			TxSupported: f.multiPort.Tx,
			RxSupported: f.multiPort.Rx,
			BasePort:    uint32(f.multiPort.TxBasePort),
			TotalPorts:  uint32(f.multiPort.TxPorts),
		}
	}

	if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(via.UdpAddr.Port()) {
		// The other side sent us a handshake from a different port, make sure
		// we send responses back to the BasePort
		via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort))
	}

	hostinfo := &HostInfo{
		ConnectionState:   ci,
		localIndexId:      myIndex,
		remoteIndexId:     hs.Details.InitiatorIndex,
		vpnAddrs:          vpnAddrs,
		HandshakePacket:   make(map[uint8][]byte, 0),
		lastHandshakeTime: hs.Details.Time,
		multiportTx:       multiportTx,
		multiportRx:       multiportRx,
		relayState: RelayState{
			relays:         nil,
			relayForByAddr: map[netip.Addr]*Relay{},
			relayForByIdx:  map[uint32]*Relay{},
		},
	}

	msgRxL := f.l.WithFields(m{
		"vpnAddrs":       vpnAddrs,
		"from":           via,
		"certName":       certName,
		"certVersion":    certVersion,
		"fingerprint":    fingerprint,
		"issuer":         issuer,
		"initiatorIndex": hs.Details.InitiatorIndex,
		"responderIndex": hs.Details.ResponderIndex,
		"remoteIndex":    h.RemoteIndex,
		"multiportTx":    multiportTx,
		"multiportRx":    multiportRx,
		"handshake":      m{"stage": 1, "style": "ix_psk0"},
	})
	if anyVpnAddrsInCommon {
		msgRxL.Info("Handshake message received")
	} else {
		//todo warn if not lighthouse or relay?
		msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
	}

	// Fill in our side of the handshake details for the stage-2 reply.
	hs.Details.ResponderIndex = myIndex
	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
	if hs.Details.Cert == nil {
		msgRxL.WithField("myCertVersion", ci.myCert.Version()).
			Error("Unable to handshake with host because no certificate handshake bytes is available")
		return
	}
	hs.Details.CertVersion = uint32(ci.myCert.Version())
	// Update the time in case their clock is way off from ours
	hs.Details.Time = uint64(time.Now().UnixNano())

	hsBytes, err := hs.Marshal()
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
		return
	}

	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
		return
	} else if dKey == nil || eKey == nil {
		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
		return
	}

	// Cache both handshake packets so duplicates can be detected and the
	// stage-2 response can be resent without re-running noise.
	hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
	copy(hostinfo.HandshakePacket[0], packet[header.Len:])

	// Regardless of whether you are the sender or receiver, you should arrive here
	// and complete standing up the connection.
	hostinfo.HandshakePacket[2] = make([]byte, len(msg))
	copy(hostinfo.HandshakePacket[2], msg)

	// We are sending handshake packet 2, so we don't expect to receive
	// handshake packet 2 from the initiator.
	ci.window.Update(f.l, 2)

	ci.peerCert = remoteCert
	ci.dKey = NewNebulaCipherState(dKey)
	ci.eKey = NewNebulaCipherState(eKey)

	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
	if !via.IsRelayed {
		hostinfo.SetRemote(via.UdpAddr)
	}
	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)

	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
	if err != nil {
		switch err {
		case ErrAlreadySeen:
			if hostinfo.multiportRx {
				// The other host is sending to us with multiport, so only grab the IP
				via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
			}
			// Update remote if preferred
			if existing.SetRemoteIfPreferred(f.hostMap, via) {
				// Send a test packet to ensure the other side has also switched to
				// the preferred remote
				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
			}

			msg = existing.HandshakePacket[2]
			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
			if !via.IsRelayed {
				// Re-send the cached stage-2 response exactly once: via the
				// raw multiport socket when negotiated, otherwise via the
				// normal outside socket. (Previously the packet was also
				// written to the outside socket unconditionally before this
				// branch, duplicating the send.)
				var err error
				if multiportTx {
					// TODO remove alloc here
					raw := make([]byte, len(msg)+udp.RawOverhead)
					copy(raw[udp.RawOverhead:], msg)
					err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr)
				} else {
					err = f.outside.WriteTo(msg, via.UdpAddr)
				}
				if err != nil {
					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
						WithError(err).Error("Failed to send handshake message")
				} else {
					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
						Info("Handshake message sent")
				}
				return
			} else {
				if via.relay == nil {
					f.l.Error("Handshake send failed: both addr and via.relay are nil.")
					return
				}
				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
				f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
					Info("Handshake message sent")
				return
			}
		case ErrExistingHostInfo:
			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("oldHandshakeTime", existing.lastHandshakeTime).
				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
				Info("Handshake too old")

			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
			f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
			return
		case ErrLocalIndexCollision:
			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
				Error("Failed to add HostInfo due to localIndex collision")
			return
		default:
			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
			// And we forget to update it here
			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
				Error("Failed to add HostInfo to HostMap")
			return
		}
	}

	// Do the send
	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
	if !via.IsRelayed {
		if multiportTx {
			// TODO remove alloc here
			raw := make([]byte, len(msg)+udp.RawOverhead)
			copy(raw[udp.RawOverhead:], msg)
			err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr)
		} else {
			err = f.outside.WriteTo(msg, via.UdpAddr)
		}
		log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
		if err != nil {
			log.WithError(err).Error("Failed to send handshake")
		} else {
			log.Info("Handshake message sent")
		}
	} else {
		if via.relay == nil {
			f.l.Error("Handshake send failed: both addr and via.relay are nil.")
			return
		}
		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])

		// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
		// it's correctly marked as working.
		via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			Info("Handshake message sent")
	}

	f.connectionManager.AddTrafficWatch(hostinfo)

	hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
	return
}
func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
}
hh.Lock()
defer hh.Unlock()
hostinfo := hh.hostinfo
if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
}
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
// the handshake state machine. Tear it down
return true
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
if (f.multiPort.Tx || f.multiPort.Rx) && hs.Details.ResponderMultiPort != nil {
hostinfo.multiportTx = hs.Details.ResponderMultiPort.RxSupported && f.multiPort.Tx
hostinfo.multiportRx = hs.Details.ResponderMultiPort.TxSupported && f.multiPort.Rx
}
if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(via.UdpAddr.Port()) {
// The other side sent us a handshake from a different port, make sure
// we send responses back to the BasePort
via.UdpAddr = netip.AddrPortFrom(
via.UdpAddr.Addr(),
uint16(hs.Details.ResponderMultiPort.BasePort),
)
}
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
return true
}
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp).
WithField("certVpnNetworks", rc.Networks())
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
}
e.Info("Invalid certificate from host")
return true
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
return true
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("No networks in certificate")
return true
}
vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
hostinfo.remoteIndexId = hs.Details.ResponderIndex
hostinfo.lastHandshakeTime = hs.Details.Time
// Store their cert and our symmetric keys
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
correctHostResponded := false
anyVpnAddrsInCommon := false
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
}
if hostinfo.vpnAddrs[0] == network.Addr() {
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
correctHostResponded = true
}
}
// Ensure the right host responded
if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
// Release our old handshake from pending, it should not continue
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient
newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{}
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnAddrs = vpnAddrs
f.sendCloseTunnel(hostinfo)
})
return true
}
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)).
WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx)
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
}
if len(hh.packetStore) > 0 {
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
}
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration)
return false
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test"
@@ -27,7 +28,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
v1Credential: nil,
}
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -100,3 +101,137 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
// GetCertState returns a fixed certificate state that initiates handshakes
// with version-2 certificates; this is all the mock writer needs for tests.
func (mw *mockEncWriter) GetCertState() *CertState {
	state := CertState{initiatingVersion: cert.Version2}
	return &state
}
// TestValidatePeerCert exercises validatePeerCert's three outcomes:
// accepted with a network in common, accepted with none in common, and
// rejected (self-handshake or empty cert).
func TestValidatePeerCert(t *testing.T) {
	l := test.NewLogger()

	myNetwork := netip.MustParsePrefix("10.0.0.1/24")

	// Table holding only our own vpn address (full-length prefix).
	addrTable := new(bart.Lite)
	addrTable.Insert(netip.PrefixFrom(myNetwork.Addr(), myNetwork.Addr().BitLen()))

	// Table holding our vpn network (masked prefix).
	netTable := new(bart.Lite)
	netTable.Insert(myNetwork.Masked())

	// makeHM builds a fresh HandshakeManager wired to a minimal Interface
	// carrying the address/network tables above.
	makeHM := func() *HandshakeManager {
		hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
		hm.f = &Interface{
			handshakeManager:   hm,
			pki:                &PKI{},
			l:                  l,
			myVpnAddrsTable:    addrTable,
			myVpnNetworksTable: netTable,
			lightHouse:         hm.lightHouse,
		}
		return hm
	}

	// certWith wraps the given networks in a cached dummy certificate.
	certWith := func(networks ...netip.Prefix) *cert.CachedCertificate {
		return &cert.CachedCertificate{
			Certificate: &dummyCert{name: "peer", networks: networks},
		}
	}

	via := ViaSender{
		UdpAddr:   netip.MustParseAddrPort("198.51.100.7:4242"),
		IsRelayed: true, // skip the remote allow list (covered separately)
	}

	t.Run("addr inside our networks sets anyVpnAddrsInCommon", func(t *testing.T) {
		hm := makeHM()
		// 10.0.0.2 falls inside our 10.0.0.0/24
		addrs, common, ok := hm.validatePeerCert(via, certWith(netip.MustParsePrefix("10.0.0.2/24")))
		assert.True(t, ok)
		assert.True(t, common)
		assert.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.2")}, addrs)
	})

	t.Run("addr outside our networks leaves anyVpnAddrsInCommon false", func(t *testing.T) {
		hm := makeHM()
		addrs, common, ok := hm.validatePeerCert(via, certWith(netip.MustParsePrefix("192.168.1.5/24")))
		assert.True(t, ok)
		assert.False(t, common)
		assert.Equal(t, []netip.Addr{netip.MustParseAddr("192.168.1.5")}, addrs)
	})

	t.Run("any matching network is enough", func(t *testing.T) {
		hm := makeHM()
		addrs, common, ok := hm.validatePeerCert(via, certWith(
			netip.MustParsePrefix("192.168.1.5/24"),
			netip.MustParsePrefix("10.0.0.42/24"),
		))
		assert.True(t, ok)
		assert.True(t, common)
		assert.Len(t, addrs, 2)
	})

	t.Run("self-handshake is rejected", func(t *testing.T) {
		hm := makeHM()
		// 10.0.0.1 is in myVpnAddrsTable
		addrs, common, ok := hm.validatePeerCert(via, certWith(netip.MustParsePrefix("10.0.0.1/24")))
		assert.False(t, ok)
		assert.False(t, common)
		assert.Nil(t, addrs)
	})

	t.Run("cert with no networks is rejected", func(t *testing.T) {
		hm := makeHM()
		addrs, common, ok := hm.validatePeerCert(via, certWith())
		assert.False(t, ok)
		assert.False(t, common)
		assert.Nil(t, addrs)
	})
}
// TestHandleIncomingDispatch verifies that malformed handshake headers are
// dropped before any pending handshake state is created.
func TestHandleIncomingDispatch(t *testing.T) {
	l := test.NewLogger()

	// makeHM builds a HandshakeManager with just enough Interface wiring
	// for dispatch-level checks.
	makeHM := func() *HandshakeManager {
		hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
		hm.f = &Interface{
			handshakeManager: hm,
			pki:              &PKI{},
			l:                l,
		}
		return hm
	}

	via := ViaSender{
		UdpAddr:   netip.MustParseAddrPort("198.51.100.7:4242"),
		IsRelayed: true, // bypass remote allow list
	}

	// A packet body of zero length is fine for these tests: dispatch is
	// gated on header fields, and we assert that we never reach noise/cert
	// processing for any of the malformed shapes here.
	pkt := make([]byte, header.Len)

	t.Run("unsupported subtype dropped", func(t *testing.T) {
		hm := makeHM()
		hdr := &header.H{Type: header.Handshake, Subtype: header.MessageSubType(99), MessageCounter: 1}
		hm.HandleIncoming(via, pkt, hdr)
		assert.Empty(t, hm.indexes, "no pending handshake should be created")
	})

	t.Run("stage-1 with non-zero RemoteIndex dropped", func(t *testing.T) {
		hm := makeHM()
		hdr := &header.H{
			Type:           header.Handshake,
			Subtype:        header.HandshakeIXPSK0,
			RemoteIndex:    0xdeadbeef,
			MessageCounter: 1,
		}
		hm.HandleIncoming(via, pkt, hdr)
		assert.Empty(t, hm.indexes, "spoofed stage-1 must not create a pending machine")
	})

	t.Run("continuation with no matching pending index dropped", func(t *testing.T) {
		hm := makeHM()
		hdr := &header.H{
			Type:           header.Handshake,
			Subtype:        header.HandshakeIXPSK0,
			RemoteIndex:    0xcafef00d,
			MessageCounter: 2,
		}
		hm.HandleIncoming(via, pkt, hdr)
		assert.Empty(t, hm.indexes, "orphan stage-2 must not create state")
	})
}

View File

@@ -174,6 +174,10 @@ func (h *H) SubTypeName() string {
return SubTypeName(h.Type, h.Subtype)
}
// IsValidSubType reports whether this header's subtype is registered for its
// message type (delegates to the package-level IsValidSubType lookup).
func (h *H) IsValidSubType() bool {
	return IsValidSubType(h.Type, h.Subtype)
}
// SubTypeName will transform a nebula message sub type into a human string
func SubTypeName(t MessageType, s MessageSubType) string {
if n, ok := subTypeMap[t]; ok {
@@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string {
return "unknown"
}
// IsValidSubType reports whether subtype s is a registered subtype for
// message type t in subTypeMap.
func IsValidSubType(t MessageType, s MessageSubType) bool {
	subTypes, ok := subTypeMap[t]
	if !ok {
		return false
	}
	_, ok = (*subTypes)[s]
	return ok
}
// NewHeader turns bytes into a header
func NewHeader(b []byte) (*H, error) {
h := new(H)

View File

@@ -1,9 +1,11 @@
package nebula
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
@@ -13,10 +15,10 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
)
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
@@ -60,7 +62,7 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix]
l *logrus.Logger
l *slog.Logger
}
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
@@ -319,7 +321,7 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
hm := newHostMap(l)
hm.reload(c, true)
@@ -327,13 +329,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm.reload(c, false)
})
l.WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")
l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
return hm
}
func newHostMap(l *logrus.Logger) *HostMap {
func newHostMap(l *slog.Logger) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
@@ -352,7 +353,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
hm.l.Warn("Failed to parse preferred ranges, ignoring",
"error", err,
"range", rawPreferredRanges,
)
continue
}
@@ -361,7 +365,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
hm.l.Info("preferred_ranges changed",
"oldPreferredRanges", *oldRanges,
"newPreferredRanges", preferredRanges,
)
}
}
}
@@ -494,10 +501,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
hm.Indexes = map[uint32]*HostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Hostmap hostInfo deleted",
"hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
)
}
if isLastHostinfo {
@@ -610,9 +618,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns {
if f.dnsServer != nil {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
}
for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
@@ -621,10 +629,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
Debug("Hostmap vpnIp added")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Hostmap vpnIp added",
"hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
)
}
}
@@ -790,18 +799,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
}
}
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
if i == nil {
return logrus.NewEntry(l)
return l
}
li := l.WithField("vpnAddrs", i.vpnAddrs).
WithField("localIndex", i.localIndexId).
WithField("remoteIndex", i.remoteIndexId)
li := l.With(
"vpnAddrs", i.vpnAddrs,
"localIndex", i.localIndexId,
"remoteIndex", i.remoteIndexId,
)
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Certificate.Name())
li = li.With("certName", peerCert.Certificate.Name())
}
}
@@ -810,14 +822,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage
var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
if l.Level >= logrus.TraceLevel {
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
if l.Enabled(context.Background(), logging.LevelTrace) {
l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
"interfaceName", i.Name,
"allow", allow,
)
}
if !allow {
@@ -835,8 +850,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
}
if !addr.IsValid() {
if l.Level >= logrus.DebugLevel {
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("addr was invalid", "localAddr", rawAddr)
}
continue
}
@@ -844,8 +859,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel {
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
if l.Enabled(context.Background(), logging.LevelTrace) {
l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
"localAddr", addr,
"allowed", isAllowed,
)
}
if !isAllowed {
continue

View File

@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
c := config.NewC(test.NewLogger())
hm := NewHostMapFromConfig(l, c)

View File

@@ -1,5 +1,4 @@
//go:build e2e_testing
// +build e2e_testing
package nebula

119
inside.go
View File

@@ -1,9 +1,10 @@
package nebula
import (
"context"
"log/slog"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
@@ -15,8 +16,11 @@ import (
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while validating outbound packet",
"packet", packet,
"error", err,
)
}
return
}
@@ -36,7 +40,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
f.l.Error("Failed to forward to tun", "error", err)
}
}
// Otherwise, drop. On linux, we should never see these packets - Linux
@@ -55,10 +59,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
"vpnAddr", fwPacket.RemoteAddr,
"fwPacket", fwPacket,
)
}
return
}
@@ -73,11 +78,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping outbound packet",
"fwPacket", fwPacket,
"reason", dropReason,
)
}
}
}
@@ -94,7 +99,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
_, err := f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
f.l.Error("Failed to write to tun", "error", err)
}
}
@@ -109,11 +114,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
}
if len(out) > iputil.MaxRejectPacketSize {
if f.l.GetLevel() >= logrus.InfoLevel {
f.l.
WithField("packet", packet).
WithField("outPacket", out).
Info("rejectOutside: packet too big, not sending")
if f.l.Enabled(context.Background(), slog.LevelInfo) {
f.l.Info("rejectOutside: packet too big, not sending",
"packet", packet,
"outPacket", out,
)
}
return
}
@@ -185,10 +190,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
// This would also need to interact with unsafe_route updates through reloading the config or
// use of the use_system_route_table option
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("destination", destinationAddr).
WithField("originalGateway", gatewayAddr).
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
"destination", destinationAddr,
"originalGateway", gatewayAddr,
)
}
for i := range gateways {
@@ -214,17 +220,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
fp := &firewall.Packet{}
err := newPacket(p, false, fp)
if err != nil {
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
return
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping cached packet",
"fwPacket", fp,
"reason", dropReason,
)
}
return
}
@@ -240,9 +247,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
})
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", vpnAddr).
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
"vpnAddr", vpnAddr,
)
}
return
}
@@ -298,12 +306,12 @@ func (f *Interface) SendVia(via *HostInfo,
if noiseutil.EncryptLockNeeded {
via.ConnectionState.writeLock.Unlock()
}
via.logger(f.l).
WithField("outCap", cap(out)).
WithField("payloadLen", len(ad)).
WithField("headerLen", len(out)).
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()).
Error("SendVia out buffer not large enough for relay")
via.logger(f.l).Error("SendVia out buffer not large enough for relay",
"outCap", cap(out),
"payloadLen", len(ad),
"headerLen", len(out),
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
)
return
}
@@ -323,12 +331,12 @@ func (f *Interface) SendVia(via *HostInfo,
via.ConnectionState.writeLock.Unlock()
}
if err != nil {
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
return
}
err = f.writers[0].WriteTo(out, via.remote)
if err != nil {
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
}
f.connectionManager.RelayUsed(relay.LocalIndex)
}
@@ -384,8 +392,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
"vpnAddrs", hostinfo.vpnAddrs,
)
}
}
@@ -395,10 +405,12 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
"error", err,
"udpAddr", remote,
"counter", c,
"attemptedCounter", c,
)
return
}
@@ -411,8 +423,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
err = f.writers[q].WriteTo(out, remote)
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
"error", err,
"udpAddr", remote,
)
}
} else if hostinfo.remote.IsValid() {
if multiport {
@@ -423,8 +437,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
err = f.writers[q].WriteTo(out, hostinfo.remote)
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
"error", err,
"udpAddr", remote,
)
}
} else {
// Try to send via a relay
@@ -432,7 +448,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
"relay", relayIP,
"error", err,
)
continue
}
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)

View File

@@ -1,5 +1,4 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd
package nebula

View File

@@ -1,5 +1,4 @@
//go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd
// +build !darwin,!dragonfly,!freebsd,!netbsd,!openbsd
package nebula

View File

@@ -5,15 +5,15 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -30,7 +30,7 @@ type InterfaceConfig struct {
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
DnsServer *dnsServer
HandshakeManager *HandshakeManager
lightHouse *LightHouse
connectionManager *connectionManager
@@ -47,7 +47,7 @@ type InterfaceConfig struct {
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
l *logrus.Logger
l *slog.Logger
}
type Interface struct {
@@ -58,7 +58,7 @@ type Interface struct {
firewall *Firewall
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
dnsServer *dnsServer
createTime time.Time
lightHouse *LightHouse
myBroadcastAddrsTable *bart.Lite
@@ -86,17 +86,25 @@ type Interface struct {
conntrackCacheTimeout time.Duration
ctx context.Context
writers []udp.Conn
readers []io.ReadWriteCloser
udpRaw *udp.RawConn
wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown.
// nil means "no fatal error" (yet)
fatalErr atomic.Pointer[error]
// triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr
triggerShutdown func()
udpRaw *udp.RawConn
multiPort MultiPortConfig
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger
l *slog.Logger
}
type MultiPortConfig struct {
@@ -176,12 +184,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
cs := c.pki.getCertState()
ifce := &Interface{
ctx: ctx,
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
firewall: c.Firewall,
serveDns: c.ServeDns,
dnsServer: c.DnsServer,
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
@@ -222,18 +231,21 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
func (f *Interface) activate() error {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
if err != nil {
f.l.WithError(err).Error("Failed to get udp listen address")
f.l.Error("Failed to get udp listen address", "error", err)
}
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
f.l.Info("Nebula interface is active",
"interface", f.inside.Name(),
"networks", f.myVpnNetworks,
"build", f.version,
"udpAddr", addr,
"boringcrypto", boringEnabled(),
)
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
@@ -252,33 +264,58 @@ func (f *Interface) activate() {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
f.l.Fatal(err)
return err
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil {
f.wg.Done()
f.inside.Close()
f.l.Fatal(err)
return err
}
return nil
}
func (f *Interface) run() {
func (f *Interface) run() (func() error, error) {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
go f.listenOut(i)
f.wg.Go(func() {
f.listenOut(i)
})
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
f.wg.Go(func() {
f.listenIn(f.readers[i], i)
})
}
return func() error {
f.wg.Wait()
if e := f.fatalErr.Load(); e != nil {
return *e
}
return nil
}, nil
}
// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one
func (f *Interface) onFatal(err error) {
swapped := f.fatalErr.CompareAndSwap(nil, &err)
if !swapped {
return
}
if f.triggerShutdown != nil {
f.triggerShutdown()
}
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -286,42 +323,47 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
})
if err != nil && !f.closed.Load() {
f.l.Error("Error while reading inbound packet, closing", "error", err)
f.onFatal(err)
}
f.l.Debug("underlay reader is done", "reader", i)
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
if !f.closed.Load() {
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
f.onFatal(err)
}
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
break
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
}
f.l.Debug("overlay reader is done", "reader", i)
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -341,7 +383,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
if initial || c.HasChanged("pki.disconnect_invalid") {
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
if !initial {
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
}
}
}
@@ -355,7 +397,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
f.l.Error("Error while creating firewall during reload", "error", err)
return
}
@@ -368,10 +410,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
} else {
fw.Conntrack = conntrack
}
@@ -379,10 +422,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
f.firewall = fw
oldFw.Destroy()
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed")
f.l.Info("New firewall has been installed",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
}
func (f *Interface) reloadSendRecvError(c *config.C) {
@@ -404,8 +448,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
}
}
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
Info("Loaded send_recv_error config")
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
}
}
@@ -428,8 +471,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
}
}
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
Info("Loaded accept_recv_error config")
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
}
}
@@ -505,15 +547,23 @@ func (f *Interface) GetCertState() *CertState {
}
func (f *Interface) Close() error {
var errs []error
f.closed.Store(true)
for _, u := range f.writers {
// Release the udp readers
for i, u := range f.writers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing udp socket")
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
errs = append(errs, err)
}
}
// Release the tun device
return f.inside.Close()
// Release the tun device (closing the tun also closes all readers)
closeErr := f.inside.Close()
if closeErr != nil {
errs = append(errs, closeErr)
}
f.wg.Done()
return errors.Join(errs...)
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
@@ -15,10 +16,10 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
)
@@ -69,18 +70,19 @@ type LightHouse struct {
// Addr's of relays that can be used by peers to access me
relaysForMe atomic.Pointer[[]netip.Addr]
queryChan chan netip.Addr
updateTrigger chan struct{}
queryChan chan netip.Addr
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
l *slog.Logger
}
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 {
@@ -105,6 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
updateTrigger: make(chan struct{}, 1),
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
@@ -131,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
case *util.ContextualError:
v.Log(l)
case error:
l.WithError(err).Error("failed to reload lighthouse")
l.Error("failed to reload lighthouse", "error", err)
}
})
@@ -203,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
addr := addrs[0].Unmap()
if lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
"addr", rawAddr,
"entry", i+1,
)
continue
}
@@ -222,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
if !initial {
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
lh.l.Info("lighthouse.interval changed",
"interval", lh.interval.Load(),
)
if lh.updateCancel != nil {
// May not always have a running routine
@@ -316,6 +323,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed")
lh.TriggerUpdate()
}
}
@@ -333,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
for _, v := range c.GetStringSlice("relay.relays", nil) {
configRIP, err := netip.ParseAddr(v)
if err != nil {
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
lh.l.Warn("Parse relay from config failed",
"relay", v,
"error", err,
)
} else {
lh.l.WithField("relay", v).Info("Read relay from config")
lh.l.Info("Read relay from config", "relay", v)
relaysForMe = append(relaysForMe, configRIP)
}
}
@@ -360,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
}
if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
"vpnAddr", addr,
"networks", lh.myVpnNetworks,
)
}
out[i] = addr
}
@@ -432,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
}
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
"vpnAddr", vpnAddr,
"networks", lh.myVpnNetworks,
"entry", i+1,
)
}
vals, ok := v.([]any)
@@ -534,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
for _, addr := range allVpnAddrs {
srm := lh.addrMap[addr]
if srm == rm {
delete(lh.addrMap, addr)
if lh.l.Level >= logrus.DebugLevel {
lh.l.Debugf("deleting %s from lighthouse.", addr)
if debugEnabled {
lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
}
}
}
@@ -656,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddrs", vpnAddrs,
"udpAddr", to,
"allow", allow,
)
}
if !allow {
return false
@@ -675,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
udpAddr := protoV4AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
}
if !allow {
@@ -695,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
udpAddr := protoV6AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
}
if !allow {
@@ -713,21 +739,14 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
l := lh.GetLighthouses()
for i := range l {
if l[i] == vpnAddr {
return true
}
}
return false
return slices.Contains(l, vpnAddr)
}
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
l := lh.GetLighthouses()
for i := range vpnAddrs {
for j := range l {
if l[j] == vpnAddrs[i] {
return true
}
if slices.Contains(l, vpnAddrs[i]) {
return true
}
}
return false
@@ -779,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
if v == cert.Version1 {
if !addr.Is4() {
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
Error("Can't query lighthouse for v6 address using a v1 protocol")
lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
@@ -791,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v1Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v1 query payload")
lh.l.Error("Failed to marshal lighthouse v1 query payload",
"error", err,
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -808,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v2Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v2 query payload")
lh.l.Error("Failed to marshal lighthouse v2 query payload",
"error", err,
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -819,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried++
} else {
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
lh.l.Debug("unsupported protocol version",
"op", "query",
"queryVpnAddr", addr,
"version", v,
)
continue
}
}
@@ -848,11 +877,24 @@ func (lh *LightHouse) StartUpdateWorker() {
return
case <-clockSource.C:
continue
case <-lh.updateTrigger:
continue
}
}
}()
}
// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking
// operation intended to be called after a handshake completes with a lighthouse,
// so the lighthouse has our current addresses without waiting for the next
// periodic update.
func (lh *LightHouse) TriggerUpdate() {
	select {
	case lh.updateTrigger <- struct{}{}:
	default:
		// updateTrigger has capacity 1; if a trigger is already pending,
		// coalesce with it instead of blocking the caller.
	}
}
func (lh *LightHouse) SendUpdate() {
var v4 []*V4AddrPort
var v6 []*V6AddrPort
@@ -898,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
if v == cert.Version1 {
if v1Update == nil {
if !lh.myVpnNetworks[0].Addr().Is4() {
lh.l.WithField("lighthouseAddr", lhVpnAddr).
Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
"lighthouseAddr", lhVpnAddr,
)
continue
}
var relays []uint32
@@ -923,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
v1Update, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
Error("Error while marshaling for lighthouse v1 update")
lh.l.Error("Error while marshaling for lighthouse v1 update",
"error", err,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -950,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
v2Update, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
Error("Error while marshaling for lighthouse v2 update")
lh.l.Error("Error while marshaling for lighthouse v2 update",
"error", err,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -960,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
updated++
} else {
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v)
lh.l.Debug("unsupported protocol version",
"op", "update",
"version", v,
)
continue
}
}
@@ -974,7 +1024,7 @@ type LightHouseHandler struct {
out []byte
pb []byte
meta *NebulaMeta
l *logrus.Logger
l *slog.Logger
}
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
@@ -1023,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
lhh.l.Error("Failed to unmarshal lighthouse packet",
"error", err,
"vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return
}
if n.Details == nil {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
lhh.l.Error("Invalid lighthouse update",
"vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return
}
@@ -1058,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("I don't answer queries, but received one", "from", addr)
}
return
}
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Dropping malformed HostQuery",
"from", fromVpnAddrs,
"details", n.Details,
)
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
"vpnAddrs", fromVpnAddrs,
"queryVpnAddr", queryVpnAddr,
)
}
return
}
@@ -1101,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
}
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
lhh.l.Error("Failed to marshal lighthouse host query reply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1129,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("unable to punch to host, no addresses in common",
"to", crt.Networks(),
)
}
}
}
@@ -1156,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
}
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
lhh.l.Error("Failed to marshal lighthouse host was queried for",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1198,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("version", v).Debug("unsupported protocol version")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("unsupported protocol version",
"op", "coalesceAnswers",
"version", v,
)
}
}
}
@@ -1212,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Error("dropping malformed HostQueryReply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
}
return
}
@@ -1238,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
}
return
}
@@ -1262,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Host sent invalid update",
"vpnAddrs", fromVpnAddrs,
"answer", detailsVpnAddr,
)
}
return
}
@@ -1285,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
switch useVersion {
case cert.Version1:
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
"vpnAddrs", fromVpnAddrs,
)
return
}
vpnAddrB := fromVpnAddrs[0].As4()
@@ -1293,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
case cert.Version2:
// do nothing, we want to send a blank message
default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
lhh.l.Error("invalid protocol version", "useVersion", useVersion)
return
}
ln, err := n.MarshalTo(lhh.pb)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
lhh.l.Error("Failed to marshal lighthouse host update ack",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1316,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("dropping invalid HostPunchNotification",
"details", n.Details,
"error", err,
)
}
return
}
@@ -1334,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}()
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Punching",
"vpnPeer", vpnPeer,
"logVpnAddr", logVpnAddr,
)
}
}
@@ -1360,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
if lhh.lh.punchy.GetRespond() {
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Sending a nebula test packet",
"vpnAddr", detailsVpnAddr,
)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine

View File

@@ -1,7 +1,6 @@
package nebula
import (
"context"
"encoding/binary"
"fmt"
"net/netip"
@@ -42,14 +41,14 @@ func Test_lhStaticMapping(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
_, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err)
lh2 := "10.128.0.3"
c = config.NewC(l)
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
_, err = NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
}
@@ -71,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
}
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
@@ -99,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}
c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh, err := NewLightHouseFromConfig(b.Context(), l, c, cs, nil, nil)
require.NoError(b, err)
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
@@ -202,7 +201,7 @@ func TestLighthouse_Memory(t *testing.T) {
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{}
require.NoError(t, err)
lhh := lh.NewRequestHandler()
@@ -288,7 +287,7 @@ func TestLighthouse_reload(t *testing.T) {
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err)
nc := map[string]any{
@@ -523,7 +522,7 @@ func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
@@ -589,7 +588,7 @@ func TestLighthouse_DeletesWork(t *testing.T) {
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}

View File

@@ -1,45 +0,0 @@
package nebula
import (
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// configLogger applies the logging.* settings from c to the logrus logger l:
// level, timestamp behavior, and output format ("text" or "json"). It returns
// an error when the configured level or format is unrecognized. Note the level
// may already have been applied when a later format error is returned.
func configLogger(l *logrus.Logger, c *config.C) error {
	// set up our logging level
	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
	}
	l.SetLevel(logLevel)

	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
	// FullTimestamp is only turned on when the user supplied a custom format;
	// an empty setting falls back to RFC3339.
	timestampFormat := c.GetString("logging.timestamp_format", "")
	fullTimestamp := (timestampFormat != "")
	if timestampFormat == "" {
		timestampFormat = time.RFC3339
	}

	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
	switch logFormat {
	case "text":
		l.Formatter = &logrus.TextFormatter{
			TimestampFormat:  timestampFormat,
			FullTimestamp:    fullTimestamp,
			DisableTimestamp: disableTimestamp,
		}
	case "json":
		l.Formatter = &logrus.JSONFormatter{
			TimestampFormat:  timestampFormat,
			DisableTimestamp: disableTimestamp,
		}
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
	}

	return nil
}

233
logging/logger.go Normal file
View File

@@ -0,0 +1,233 @@
// Package logging wires the nebula runtime-reconfigurable slog handler used
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
// NewLogger, then call ApplyConfig at startup and from a config reload
// callback to push logging.level, logging.format, and
// logging.disable_timestamp changes onto the logger without rebuilding it.
package logging
import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"sync/atomic"
"time"
)
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
// here keeps the logging package from depending on config directly, which
// would cycle through the shared test helpers (test.NewLogger imports
// logging, and config's tests import test). *config.C satisfies this
// interface structurally with no adapter.
type Config interface {
	// GetString returns the value stored at key; def is presumably the
	// fallback when the key is unset — mirrors *config.C's accessor.
	GetString(key, def string) string
	// GetBool returns the boolean value stored at key, with def as the
	// presumed fallback — mirrors *config.C's accessor.
	GetBool(key string, def bool) bool
}
// LevelTrace is a custom slog level below Debug, used when logging.level is
// "trace". slog has no builtin trace level; the value is one step below
// slog.LevelDebug in slog's 4-point spacing. Callers gate trace output with
// l.Enabled(ctx, logging.LevelTrace) and emit via l.Log(ctx, logging.LevelTrace, ...).
const LevelTrace = slog.Level(-8)
// NewLogger builds a *slog.Logger backed by this package's runtime-
// reconfigurable Handler. Until ApplyConfig (or the SSH debug commands)
// adjusts it, the logger emits info-level text output with slog's default
// RFC3339Nano timestamps, so early log calls are never lost; set
// logging.disable_timestamp in config to suppress timestamps.
//
// Reconfiguration is discovered by structurally type-asserting l.Handler()
// for the optional methods SetLevel(slog.Level), SetFormat(string) error,
// and SetDisableTimestamp(bool). Alternate handler implementations (tests,
// platform sinks) implement whichever subset they want; a handler with none
// of them simply opts out of reconfiguration.
func NewLogger(w io.Writer) *slog.Logger {
	handler := NewHandler(w)
	return slog.New(handler)
}
// NewHandler constructs the *Handler that NewLogger wraps. It is exported so
// platform-specific sinks (e.g. cmd/nebula-service/logs_windows.go) can wrap
// it with extra behavior — such as tagging records with an Event Log
// severity — while reusing the level / format / timestamp / WithAttrs
// machinery implemented here. The handler starts at info level; both inner
// handlers share the root's LevelVar through a single HandlerOptions.
func NewHandler(w io.Writer) *Handler {
	h := &Handler{root: &handlerRoot{}}
	h.root.level.Set(slog.LevelInfo)
	shared := &slog.HandlerOptions{Level: &h.root.level}
	h.text = slog.NewTextHandler(w, shared)
	h.json = slog.NewJSONHandler(w, shared)
	return h
}
// handlerRoot carries the reconfiguration state shared by every logger
// derived from a NewHandler call. All fields are consulted on the log
// path and updated lock-free.
type handlerRoot struct {
	// level is the shared minimum level; NewHandler wires a pointer to it
	// into both inner handlers' HandlerOptions, so SetLevel propagates
	// everywhere at once.
	level slog.LevelVar
	// disableTimestamp makes Handler.Handle zero each record's Time, which
	// causes slog's builtin text/json handlers to omit the time attribute.
	disableTimestamp atomic.Bool
	// jsonMode picks which of the pre-derived inner handlers Handler.Handle
	// dispatches to. Flipping it propagates instantly to every derived logger
	// without rebuilding or chain-replaying anything.
	jsonMode atomic.Bool
}
// Handler is the slog.Handler returned by NewHandler. It holds two
// pre-derived slog handlers -- one text, one json -- both built from the
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
// effect immediately across the whole process without having to rebuild
// any derived loggers.
type Handler struct {
	// root is shared (not copied) by every handler derived via
	// WithAttrs/WithGroup, so reconfiguration reaches all of them.
	root *handlerRoot
	// text and json carry identical attr/group state; only one is used per
	// Handle call, selected by root.jsonMode.
	text slog.Handler
	json slog.Handler
}
// Enabled reports whether a record at level l would be emitted under the
// handler's current shared minimum level.
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
	return l >= h.root.level.Level()
}
// Handle emits r through whichever inner handler the current format selects,
// first zeroing the record time when timestamps are disabled (slog's builtin
// handlers skip the time attribute for a zero time).
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
	if h.root.disableTimestamp.Load() {
		r.Time = time.Time{}
	}
	sink := h.text
	if h.root.jsonMode.Load() {
		sink = h.json
	}
	return sink.Handle(ctx, r)
}
// WithAttrs derives a handler whose records carry attrs. Both inner handlers
// are derived together so the accumulated state survives a later format
// switch; the root is shared, keeping reconfiguration in effect.
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) > 0 {
		return &Handler{
			root: h.root,
			text: h.text.WithAttrs(attrs),
			json: h.json.WithAttrs(attrs),
		}
	}
	return h
}
// WithGroup derives a handler that nests subsequent attrs under name. As
// with WithAttrs, both inner handlers are derived in lockstep and the root
// is shared so format/level changes still apply to the derived logger.
func (h *Handler) WithGroup(name string) slog.Handler {
	if name != "" {
		return &Handler{
			root: h.root,
			text: h.text.WithGroup(name),
			json: h.json.WithGroup(name),
		}
	}
	return h
}
// SetLevel updates the effective log level. Propagates to every derived
// logger via the shared LevelVar (also consulted by both inner handlers).
func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) }

// GetLevel reports the current log level.
func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() }
// SetFormat atomically selects the output encoding. Valid formats are
// "text" and "json"; anything else is rejected with an error and leaves the
// current format untouched. Every derived logger observes the new format on
// its next Handle call — no rebuild or registration is required.
func (h *Handler) SetFormat(format string) error {
	switch format {
	case "json":
		h.root.jsonMode.Store(true)
		return nil
	case "text":
		h.root.jsonMode.Store(false)
		return nil
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
	}
}
// GetFormat reports the currently selected format name: "text" or "json".
func (h *Handler) GetFormat() string {
	format := "text"
	if h.root.jsonMode.Load() {
		format = "json"
	}
	return format
}
// SetDisableTimestamp toggles whether Handle zeroes r.Time before
// dispatching (slog's builtin text/json handlers skip emitting the time
// attribute on a zero time).
func (h *Handler) SetDisableTimestamp(v bool) {
	h.root.disableTimestamp.Store(v)
}
// ApplyConfig reads logging.level, logging.format, and (optionally)
// logging.disable_timestamp from c and applies them to l. The reconfig
// surface is discovered by structurally type-asserting l.Handler(), so a
// foreign handler simply opts out of any capability it does not implement.
//
// nebula.Main does NOT call this function on your behalf; callers that
// want config-driven log level / format / timestamp updates invoke it at
// startup and register it as a reload callback themselves. This keeps the
// library from mutating an embedder's logger without their say-so.
func ApplyConfig(l *slog.Logger, c Config) error {
	// Capability interfaces matched structurally against the handler.
	type levelSetter interface{ SetLevel(slog.Level) }
	type formatSetter interface{ SetFormat(string) error }
	type timestampSetter interface{ SetDisableTimestamp(bool) }

	handler := l.Handler()

	lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return err
	}
	if ls, ok := handler.(levelSetter); ok {
		ls.SetLevel(lvl)
	}

	format := strings.ToLower(c.GetString("logging.format", "text"))
	if fs, ok := handler.(formatSetter); ok {
		if err := fs.SetFormat(format); err != nil {
			return err
		}
	}

	if ts, ok := handler.(timestampSetter); ok {
		ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
	}
	return nil
}
// ParseLevel converts a config-string level name ("trace", "debug", "info",
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and
// "panic" are accepted for backwards compatibility with pre-slog configs
// and both map to slog.LevelError.
func ParseLevel(s string) (slog.Level, error) {
	switch s {
	case "trace":
		return LevelTrace, nil
	case "debug":
		return slog.LevelDebug, nil
	case "info":
		return slog.LevelInfo, nil
	case "warn", "warning":
		return slog.LevelWarn, nil
	case "error", "fatal", "panic":
		// "fatal"/"panic" are legacy aliases for the error level.
		return slog.LevelError, nil
	}
	return 0, fmt.Errorf("not a valid logging level: %q", s)
}
// LevelName returns a human-readable name for a slog.Level matching the
// strings accepted by ParseLevel. Levels between the named thresholds map
// to the next coarser name; anything above warn reports "error".
func LevelName(l slog.Level) string {
	if l <= LevelTrace {
		return "trace"
	}
	if l <= slog.LevelDebug {
		return "debug"
	}
	if l <= slog.LevelInfo {
		return "info"
	}
	if l <= slog.LevelWarn {
		return "warn"
	}
	return "error"
}

View File

@@ -0,0 +1,90 @@
package logging
import (
"context"
"io"
"log/slog"
"testing"
)
// BenchmarkLogger_* compare the handler returned by NewLogger against a
// stock slog text handler. The key thing we care about is the per-log
// cost on a logger that has been derived via .With(), because that is the
// shape subsystems store on their structs (HostInfo.logger(),
// lh.l.With("subsystem", ...), etc.) and call from hot paths.
// BenchmarkLogger_Stock_RootInfo measures the per-call cost of Info on a
// stock slog logger (discard handler) with no With-derived attributes.
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
	logger := slog.New(slog.DiscardHandler)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Nebula_RootInfo measures the per-call cost of Info on
// the logger returned by NewLogger with no With-derived attributes.
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
	logger := NewLogger(io.Discard)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Stock_DerivedInfo measures Info on a stock slog logger
// derived via With -- the shape subsystems store on their structs and call
// from hot paths.
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
	logger := slog.New(slog.DiscardHandler).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Nebula_DerivedInfo measures Info on a NewLogger logger
// derived via With -- the shape subsystems store on their structs and call
// from hot paths.
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
	logger := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// Gated-off-path benchmarks: mimic the typical hot-path shape
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below
// the active level. This is the dominant pattern in inside.go/outside.go and
// what we pay on every packet.
// BenchmarkLogger_Stock_DerivedEnabledGateMiss measures the cost of an
// Enabled check that gates off a Debug log below the active level, on a
// stock slog logger derived via With.
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
	ctx := context.Background()
	logger := slog.New(slog.DiscardHandler).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if logger.Enabled(ctx, slog.LevelDebug) {
			logger.Debug("hello", "i", n)
		}
	}
}
// BenchmarkLogger_Nebula_DerivedEnabledGateMiss measures the cost of an
// Enabled check that gates off a Debug log below the active level, on a
// NewLogger logger derived via With.
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
	ctx := context.Background()
	logger := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if logger.Enabled(ctx, slog.LevelDebug) {
			logger.Debug("hello", "i", n)
		}
	}
}

93
main.go
View File

@@ -3,13 +3,13 @@ package nebula
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"runtime/debug"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/sshd"
@@ -20,7 +20,7 @@ import (
type m = map[string]any
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() {
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
buildVersion = moduleVersion()
}
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
// Print the config if in test, the exit comes later
if configTest {
b, err := yaml.Marshal(c.Settings)
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
// Print the final config
l.Println(string(b))
l.Info(string(b))
}
err := configLogger(l, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
}
c.RegisterReloadCallback(func(c *config.C) {
err := configLogger(l, c)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
}
})
pki, err := NewPKIFromConfig(l, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
}
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
sshStart = nil
}
}
@@ -99,19 +82,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
routines = 1
}
if routines > 1 {
l.WithField("routines", routines).Info("Using multiple routines")
l.Info("Using multiple routines", "routines", routines)
}
} else {
// deprecated and undocumented
tunQueues := c.GetInt("tun.routines", 1)
udpQueues := c.GetInt("listen.routines", 1)
if tunQueues > udpQueues {
routines = tunQueues
} else {
routines = udpQueues
}
routines = max(tunQueues, udpQueues)
if routines != 1 {
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
}
}
@@ -124,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
conntrackCacheTimeout = 1 * time.Second
}
if conntrackCacheTimeout > 0 {
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
}
var tun overlay.Device
@@ -170,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
@@ -205,27 +184,19 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
messageMetrics = newMessageMetricsOnlyRecvError()
}
useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false)
handshakeConfig := HandshakeConfig{
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
useRelays: useRelays,
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics,
}
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
serveDns := false
if c.GetBool("lighthouse.serve_dns", false) {
if c.GetBool("lighthouse.am_lighthouse", false) {
serveDns = true
} else {
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
}
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
if err != nil {
l.Warn("Failed to start DNS responder", "error", err)
}
ifConfig := &InterfaceConfig{
@@ -234,7 +205,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Outside: udpConns[0],
pki: pki,
Firewall: fw,
ServeDns: serveDns,
DnsServer: ds,
HandshakeManager: handshakeManager,
connectionManager: connManager,
lightHouse: lightHouse,
@@ -304,7 +275,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
go handshakeManager.Run(ctx)
}
statsStart, err := startStats(l, c, buildVersion, configTest)
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
}
@@ -317,23 +288,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
attachCommands(l, c, ssh, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()
if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
}
return &Control{
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
state: StateReady,
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: stats.Start,
dnsStart: ds.Start,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
}, nil
}

View File

@@ -13,6 +13,8 @@ type MessageMetrics struct {
rxUnknown metrics.Counter
txUnknown metrics.Counter
rxInvalid metrics.Counter
}
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
@@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int
}
}
}
func (m *MessageMetrics) RxInvalid(i int64) {
if m != nil && m.rxInvalid != nil {
m.rxInvalid.Inc(i)
}
}
func newMessageMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
@@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics {
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil),
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -60,29 +60,9 @@ message NebulaPing {
uint64 Time = 2;
}
message NebulaHandshake {
NebulaHandshakeDetails Details = 1;
bytes Hmac = 2;
}
message MultiPortDetails {
bool RxSupported = 1;
bool TxSupported = 2;
uint32 BasePort = 3;
uint32 TotalPorts = 4;
}
message NebulaHandshakeDetails {
bytes Cert = 1;
uint32 InitiatorIndex = 2;
uint32 ResponderIndex = 3;
uint64 Cookie = 4;
uint64 Time = 5;
uint32 CertVersion = 8;
MultiPortDetails InitiatorMultiPort = 6;
MultiPortDetails ResponderMultiPort = 7;
}
// NebulaHandshake / NebulaHandshakeDetails moved to
// handshake/handshake.proto. The handshake package speaks that wire format
// directly via a hand-written encoder/decoder.
message NebulaControl {
enum MessageType {

View File

@@ -15,14 +15,12 @@ type endianness interface {
var noiseEndianness endianness = binary.BigEndian
type NebulaCipherState struct {
c noise.Cipher
//k [32]byte
//n uint64
c cipher.AEAD
}
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
return &NebulaCipherState{c: s.Cipher()}
x := s.Cipher()
return &NebulaCipherState{c: x.(cipher.AEAD)}
}
// EncryptDanger encrypts and authenticates a given payload.
@@ -46,7 +44,7 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
out = s.c.Seal(out, nb, plaintext, ad)
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
return out, nil
} else {
@@ -61,7 +59,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
nb[2] = 0
nb[3] = 0
noiseEndianness.PutUint64(nb[4:], n)
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
return s.c.Open(out, nb, ciphertext, ad)
} else {
return []byte{}, nil
}
@@ -69,7 +67,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
func (s *NebulaCipherState) Overhead() int {
if s != nil {
return s.c.(cipher.AEAD).Overhead()
return s.c.Overhead()
}
return 0
}

View File

@@ -1,5 +1,4 @@
//go:build !boringcrypto
// +build !boringcrypto
package nebula

View File

@@ -1,15 +1,16 @@
package nebula
import (
"context"
"encoding/binary"
"errors"
"log/slog"
"net/netip"
"time"
"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/net/ipv4"
@@ -19,208 +20,239 @@ const (
minFwPacketLen = 4
)
var ErrOutOfWindow = errors.New("out of window packet")
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
// TODO: record metrics for rx holepunch/punchy packets?
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while parsing inbound packet",
"from", via,
"error", err,
"packet", packet,
)
}
}
return
}
if h.Version != header.Version {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected header version received", "from", via)
}
return
}
// Check before processing to see if this is a expected type/subtype
if !h.IsValidSubType() {
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Unexpected packet received", "from", via)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
f.messageMetrics.RxInvalid(1)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Refusing to process double encrypted packet", "from", via)
}
return
}
}
// don't keep Rx metrics for message type, since you can see those in the tun metrics
if h.Type != header.Message {
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
}
// Unencrypted packets
switch h.Type {
case header.Handshake:
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.handleRecvError(via.UdpAddr, h)
return
}
// Relay packets are special
isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay)
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
if isMessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
// At this point we should have a valid existing tunnel, verify and send
// recvError if necessary
if hostinfo == nil || hostinfo.ConnectionState == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return
}
// All remaining packets are encrypted
ci := hostinfo.ConnectionState
if !ci.window.Check(f.l, h.MessageCounter) {
return
}
// Relay packets are special
if isMessageRelay {
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache)
return
}
out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Failed to decrypt packet",
"error", err,
"from", via,
"header", h,
)
}
return
}
// Roam before we respond
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, via, h) {
return
}
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
return
}
case header.MessageRelay:
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
}
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, via)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
via = ViaSender{
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
return
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
return
}
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f)
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
switch h.Subtype {
case header.TestReply:
// No-op, useful for the Roaming and connectionManager side-effects above
case header.TestRequest:
f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out)
default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h)
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt test packet")
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(via.UdpAddr, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
hostinfo.logger(f.l).WithField("from", via).
Info("Close tunnel received, tearing down.")
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
f.relayManager.HandleControlMsg(hostinfo, out, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h)
}
}
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
}
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, via)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
"vpnAddrs", hostinfo.vpnAddrs,
"remoteIndex", h.RemoteIndex,
)
return
}
f.handleHostRoaming(hostinfo, via)
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
via = ViaSender{
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
"relayTo", relay.PeerAddr,
"error", err,
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
)
return
}
f.connectionManager.In(hostinfo)
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
return
default:
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type)
}
return
}
} else {
hostinfo.logger(f.l).Info("Unexpected target relay state",
"relayTo", relay.PeerAddr,
"relayFrom", hostinfo.vpnAddrs[0],
"targetRelayState", targetRelay.State,
)
return
}
default:
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type)
}
}
}
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -249,20 +281,27 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
}
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
}
return
}
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Info("Host roamed to new udp ip/port.")
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
"suppressSeconds", RoamingSuppressSeconds,
"udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
}
return
}
hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
"udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(via.UdpAddr)
@@ -270,23 +309,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return false
}
// If the window check fails, refuse to process the packet, but don't send a recv error
if !ci.window.Check(f.l, h.MessageCounter) {
return false
}
return true
}
var (
ErrPacketTooShort = errors.New("packet is too short")
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
@@ -336,13 +358,29 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
proto := layers.IPProtocol(data[protoAt])
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
case layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil
case layers.IPProtocolICMPv6:
if dataLen < offset+6 {
return ErrIPv6PacketTooShort
}
fp.Protocol = uint8(proto)
fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6
icmptype := data[offset+1]
switch icmptype {
case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply:
fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier
default:
fp.RemotePort = 0
}
fp.Fragment = false
return nil
case layers.IPProtocolTCP, layers.IPProtocolUDP:
if dataLen < offset+4 {
return ErrIPv6PacketTooShort
@@ -432,34 +470,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
minLen := ihl
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
minLen += minFwPacketLen
if !fp.Fragment {
if fp.Protocol == firewall.ProtoICMP {
minLen += minFwPacketLen + 2
} else {
minLen += minFwPacketLen
}
}
if len(data) < minLen {
return ErrIPv4InvalidHeaderLength
}
// Firewall packets are locally oriented
if incoming {
if incoming { // Firewall packets are locally oriented
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
}
if fp.Fragment {
fp.RemotePort = 0
fp.LocalPort = 0
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
fp.LocalPort = 0 //code would be uint16(data[ihl+1])
} else if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
}
return nil
@@ -473,34 +515,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
}
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", h).
Debugln("dropping out of window packet")
return nil, errors.New("out of window packet")
return nil, ErrOutOfWindow
}
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
return false
}
err = newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return false
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err,
"packet", out,
)
return
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
@@ -508,20 +536,19 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping inbound packet",
"fwPacket", fwPacket,
"reason", dropReason,
)
}
return false
return
}
f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
f.l.Error("Failed to write to tun", "error", err)
}
return true
}
func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
@@ -535,35 +562,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
_ = f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
WithField("udpAddr", endpoint).
Debug("Recv error sent")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Recv error sent",
"index", index,
"udpAddr", endpoint,
)
}
}
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received, ignoring")
f.l.Debug("Recv error received, ignoring",
"index", h.RemoteIndex,
"udpAddr", addr,
)
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Recv error received",
"index", h.RemoteIndex,
"udpAddr", addr,
)
}
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if hostinfo == nil {
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
f.l.Info("Someone spoofing recv_errors?",
"addr", addr,
"hostinfoRemote", hostinfo.remote,
)
return
}

View File

@@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) {
// next layer, missing length byte
err = newPacket(buffer.Bytes()[:49], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
err = nil
// A good ICMP packet
ip = layers.IPv6{
@@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) {
DstIP: net.IPv6linklocalallnodes,
}
icmp := layers.ICMPv6{}
buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
if err != nil {
panic(err)
icmp := layers.ICMPv6{
TypeCode: layers.ICMPv6TypeEchoRequest,
Checksum: 0x1234,
}
err = newPacket(buffer.Bytes(), true, p)
require.NoError(t, err)
buffer.Clear()
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp))
require.Error(t, newPacket(buffer.Bytes(), true, p))
buffer.Clear()
echo := layers.ICMPv6Echo{
Identifier: 0xabcd,
SeqNumber: 1234,
}
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo))
require.NoError(t, newPacket(buffer.Bytes(), true, p))
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0xabcd), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
@@ -574,7 +581,7 @@ func BenchmarkParseV6(b *testing.B) {
}
evilBytes := buffer.Bytes()
for i := 0; i < 200; i++ {
for range 200 {
evilBytes = append(evilBytes, hopHeader...)
}
evilBytes = append(evilBytes, lastHopHeader...)

View File

@@ -1,4 +1,6 @@
package test
// Package overlaytest provides fakes of overlay.Device for tests that do
// not want to touch a real tun device or route table.
package overlaytest
import (
"errors"
@@ -8,6 +10,9 @@ import (
"github.com/slackhq/nebula/routing"
)
// NoopTun is an overlay.Device that silently discards every read and write.
// Useful in tests that need to construct a nebula Interface but do not
// exercise the datapath. The zero value is ready to use; it carries no state.
type NoopTun struct{}
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {

View File

@@ -2,6 +2,7 @@ package overlay
import (
"fmt"
"log/slog"
"math"
"net"
"net/netip"
@@ -9,7 +10,6 @@ import (
"strconv"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
@@ -48,11 +48,14 @@ func (r Route) String() string {
return s
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
routeTree := new(bart.Table[routing.Gateways])
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
l.Warn("route MTU is not supported on this platform",
"goos", runtime.GOOS,
"route", r,
)
}
gateways := r.Via

View File

@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true)
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("1.0.0.2")
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 3)
routeTree, err := makeRouteTree(l, routes, true)
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("192.168.86.1")

View File

@@ -2,20 +2,29 @@ package overlay
import (
"fmt"
"log/slog"
"net"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
)
const DefaultMTU = 1300
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
type NameError struct {
Name string
Underlying error
}
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func (e *NameError) Error() string {
return fmt.Sprintf("could not set tun device name: %s because %s", e.Name, e.Underlying)
}
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -27,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
}
}

View File

@@ -6,12 +6,12 @@ package overlay
import (
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -23,10 +23,10 @@ type tun struct {
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
// newTun always fails on Android: tun devices cannot be created directly and
// must instead be adopted from a platform-supplied file descriptor via
// newTunFromFd. All parameters are ignored.
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
	return nil, fmt.Errorf("newTun not supported in Android")
}

Some files were not shown because too many files have changed in this diff Show More