mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
Merge remote-tracking branch 'origin/master' into multiport
This commit is contained in:
5
.github/workflows/release.yml
vendored
5
.github/workflows/release.yml
vendored
@@ -209,10 +209,11 @@ jobs:
|
|||||||
id: create_release
|
id: create_release
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GITHUB_REF_NAME: ${{ github.ref_name }}
|
||||||
run: |
|
run: |
|
||||||
cd artifacts
|
cd artifacts
|
||||||
gh release create \
|
gh release create \
|
||||||
--verify-tag \
|
--verify-tag \
|
||||||
--title "Release ${{ github.ref_name }}" \
|
--title "Release ${GITHUB_REF_NAME}" \
|
||||||
"${{ github.ref_name }}" \
|
"${GITHUB_REF_NAME}" \
|
||||||
SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz
|
SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz
|
||||||
|
|||||||
26
.github/workflows/smoke-extra.yml
vendored
26
.github/workflows/smoke-extra.yml
vendored
@@ -18,6 +18,8 @@ jobs:
|
|||||||
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
||||||
name: Run extra smoke tests
|
name: Run extra smoke tests
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
VAGRANT_DEFAULT_PROVIDER: libvirt
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v6
|
||||||
@@ -30,8 +32,13 @@ jobs:
|
|||||||
- name: add hashicorp source
|
- name: add hashicorp source
|
||||||
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
||||||
|
|
||||||
- name: install vagrant
|
- name: install vagrant and libvirt
|
||||||
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
run: |
|
||||||
|
sudo apt-get update && sudo apt-get install -y vagrant libvirt-daemon-system libvirt-dev
|
||||||
|
sudo chmod 666 /dev/kvm
|
||||||
|
sudo usermod -aG libvirt $(whoami)
|
||||||
|
sudo chmod 666 /var/run/libvirt/libvirt-sock
|
||||||
|
vagrant plugin install vagrant-libvirt
|
||||||
|
|
||||||
- name: freebsd-amd64
|
- name: freebsd-amd64
|
||||||
run: make smoke-vagrant/freebsd-amd64
|
run: make smoke-vagrant/freebsd-amd64
|
||||||
@@ -42,10 +49,19 @@ jobs:
|
|||||||
- name: netbsd-amd64
|
- name: netbsd-amd64
|
||||||
run: make smoke-vagrant/netbsd-amd64
|
run: make smoke-vagrant/netbsd-amd64
|
||||||
|
|
||||||
- name: linux-386
|
|
||||||
run: make smoke-vagrant/linux-386
|
|
||||||
|
|
||||||
- name: linux-amd64-ipv6disable
|
- name: linux-amd64-ipv6disable
|
||||||
run: make smoke-vagrant/linux-amd64-ipv6disable
|
run: make smoke-vagrant/linux-amd64-ipv6disable
|
||||||
|
|
||||||
|
# linux-386 runs last because it requires disabling KVM to use VirtualBox,
|
||||||
|
# which prevents libvirt (used by the other tests) from working after this point.
|
||||||
|
- name: install virtualbox for i386 test
|
||||||
|
run: |
|
||||||
|
sudo apt-get install -y virtualbox
|
||||||
|
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
|
||||||
|
|
||||||
|
- name: linux-386
|
||||||
|
env:
|
||||||
|
VAGRANT_DEFAULT_PROVIDER: virtualbox
|
||||||
|
run: make smoke-vagrant/linux-386
|
||||||
|
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
|
|||||||
8
.github/workflows/smoke/build-relay.sh
vendored
8
.github/workflows/smoke/build-relay.sh
vendored
@@ -16,8 +16,10 @@ relay:
|
|||||||
am_relay: true
|
am_relay: true
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
export LIGHTHOUSES="192.168.100.1 172.17.0.2:4242"
|
# TEST-NET-3 placeholder IPs; smoke-relay.sh seds them to real container IPs.
|
||||||
export REMOTE_ALLOW_LIST='{"172.17.0.4/32": false, "172.17.0.5/32": false}'
|
# Mapping: .2 lighthouse1, .3 host2, .4 host3, .5 host4.
|
||||||
|
export LIGHTHOUSES="192.168.100.1 203.0.113.2:4242"
|
||||||
|
export REMOTE_ALLOW_LIST='{"203.0.113.4/32": false, "203.0.113.5/32": false}'
|
||||||
|
|
||||||
HOST="host2" ../genconfig.sh >host2.yml <<EOF
|
HOST="host2" ../genconfig.sh >host2.yml <<EOF
|
||||||
relay:
|
relay:
|
||||||
@@ -25,7 +27,7 @@ relay:
|
|||||||
- 192.168.100.1
|
- 192.168.100.1
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
export REMOTE_ALLOW_LIST='{"172.17.0.3/32": false}'
|
export REMOTE_ALLOW_LIST='{"203.0.113.3/32": false}'
|
||||||
|
|
||||||
HOST="host3" ../genconfig.sh >host3.yml
|
HOST="host3" ../genconfig.sh >host3.yml
|
||||||
|
|
||||||
|
|||||||
18
.github/workflows/smoke/build.sh
vendored
18
.github/workflows/smoke/build.sh
vendored
@@ -5,9 +5,15 @@ set -e -x
|
|||||||
rm -rf ./build
|
rm -rf ./build
|
||||||
mkdir ./build
|
mkdir ./build
|
||||||
|
|
||||||
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
|
# Smoke containers run on a dedicated docker network whose subnet is allocated
|
||||||
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
|
# at smoke time, not known at build time. Configs are written with TEST-NET-3
|
||||||
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
|
# placeholder IPs (RFC 5737) and smoke.sh / smoke-vagrant.sh / smoke-relay.sh
|
||||||
|
# sed the real container IPs in before starting nebula.
|
||||||
|
#
|
||||||
|
# Placeholder mapping (last octet == fixed container slot):
|
||||||
|
# 203.0.113.2 -> lighthouse1, 203.0.113.3 -> host2,
|
||||||
|
# 203.0.113.4 -> host3, 203.0.113.5 -> host4.
|
||||||
|
LIGHTHOUSE_IP="203.0.113.2"
|
||||||
|
|
||||||
(
|
(
|
||||||
cd build
|
cd build
|
||||||
@@ -25,16 +31,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{
|
|||||||
../genconfig.sh >lighthouse1.yml
|
../genconfig.sh >lighthouse1.yml
|
||||||
|
|
||||||
HOST="host2" \
|
HOST="host2" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
|
||||||
../genconfig.sh >host2.yml
|
../genconfig.sh >host2.yml
|
||||||
|
|
||||||
HOST="host3" \
|
HOST="host3" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
|
||||||
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host3.yml
|
../genconfig.sh >host3.yml
|
||||||
|
|
||||||
HOST="host4" \
|
HOST="host4" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
|
||||||
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host4.yml
|
../genconfig.sh >host4.yml
|
||||||
|
|
||||||
|
|||||||
57
.github/workflows/smoke/smoke-relay.sh
vendored
57
.github/workflows/smoke/smoke-relay.sh
vendored
@@ -6,6 +6,8 @@ set -o pipefail
|
|||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
NETWORK="nebula-smoke-relay"
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
echo
|
echo
|
||||||
echo " *** cleanup"
|
echo " *** cleanup"
|
||||||
@@ -16,22 +18,53 @@ cleanup() {
|
|||||||
then
|
then
|
||||||
docker kill lighthouse1 host2 host3 host4
|
docker kill lighthouse1 host2 host3 host4
|
||||||
fi
|
fi
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1
|
||||||
}
|
}
|
||||||
|
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
|
# Create a dedicated smoke network with an explicit subnet (required for --ip
|
||||||
docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
|
# below). Probe a short list of candidates so a locally-used range doesn't
|
||||||
docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
|
# fail the whole test — we only need one to be free.
|
||||||
docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
|
docker network rm "$NETWORK" >/dev/null 2>&1 || true
|
||||||
|
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
|
||||||
|
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
|
||||||
|
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build-relay.sh.
|
||||||
|
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
|
||||||
|
PREFIX="${SUBNET%/*}"
|
||||||
|
PREFIX="${PREFIX%.*}"
|
||||||
|
LIGHTHOUSE_IP="$PREFIX.2"
|
||||||
|
HOST2_IP="$PREFIX.3"
|
||||||
|
HOST3_IP="$PREFIX.4"
|
||||||
|
HOST4_IP="$PREFIX.5"
|
||||||
|
|
||||||
|
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
|
||||||
|
for f in build/host2.yml build/host3.yml build/host4.yml; do
|
||||||
|
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
|
||||||
|
mv "$f.tmp" "$f"
|
||||||
|
done
|
||||||
|
|
||||||
|
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
|
||||||
|
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" nebula:smoke-relay -config host2.yml -test
|
||||||
|
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" nebula:smoke-relay -config host3.yml -test
|
||||||
|
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" nebula:smoke-relay -config host4.yml -test
|
||||||
|
|
||||||
|
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
set +x
|
set +x
|
||||||
@@ -76,7 +109,13 @@ docker exec host4 sh -c 'kill 1'
|
|||||||
docker exec host3 sh -c 'kill 1'
|
docker exec host3 sh -c 'kill 1'
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 5
|
|
||||||
|
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
|
||||||
|
# fixed sleep.
|
||||||
|
for _ in $(seq 1 30); do
|
||||||
|
[ -z "$(jobs -r)" ] && break
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
if [ "$(jobs -r)" ]
|
if [ "$(jobs -r)" ]
|
||||||
then
|
then
|
||||||
|
|||||||
47
.github/workflows/smoke/smoke-vagrant.sh
vendored
47
.github/workflows/smoke/smoke-vagrant.sh
vendored
@@ -8,6 +8,8 @@ export VAGRANT_CWD="$PWD/vagrant-$1"
|
|||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
NETWORK="nebula-smoke"
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
echo
|
echo
|
||||||
echo " *** cleanup"
|
echo " *** cleanup"
|
||||||
@@ -19,21 +21,51 @@ cleanup() {
|
|||||||
docker kill lighthouse1 host2
|
docker kill lighthouse1 host2
|
||||||
fi
|
fi
|
||||||
vagrant destroy -f
|
vagrant destroy -f
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1
|
||||||
}
|
}
|
||||||
|
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
# Create a dedicated smoke network with an explicit subnet (required for --ip
|
||||||
|
# below). Probe a short list of candidates so a locally-used range doesn't
|
||||||
|
# fail the whole test — we only need one to be free.
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1 || true
|
||||||
|
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
|
||||||
|
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
|
||||||
|
# .3 host2 — matches the placeholders in build.sh.
|
||||||
|
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
|
||||||
|
PREFIX="${SUBNET%/*}"
|
||||||
|
PREFIX="${PREFIX%.*}"
|
||||||
|
LIGHTHOUSE_IP="$PREFIX.2"
|
||||||
|
HOST2_IP="$PREFIX.3"
|
||||||
|
|
||||||
|
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
|
||||||
|
# This must happen before `vagrant up` rsyncs build/ into the VM for host3.
|
||||||
|
for f in build/host2.yml build/host3.yml; do
|
||||||
|
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
|
||||||
|
mv "$f.tmp" "$f"
|
||||||
|
done
|
||||||
|
|
||||||
CONTAINER="nebula:${NAME:-smoke}"
|
CONTAINER="nebula:${NAME:-smoke}"
|
||||||
|
|
||||||
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
||||||
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
|
||||||
|
|
||||||
vagrant up
|
vagrant up
|
||||||
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
|
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||||
sleep 15
|
sleep 15
|
||||||
@@ -96,7 +128,14 @@ vagrant ssh -c "ping -c1 192.168.100.2" -- -T
|
|||||||
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
|
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 1
|
|
||||||
|
# Wait up to 30s for all backgrounded jobs to exit. vagrant ssh in particular
|
||||||
|
# takes a beat to tear down after nebula exits on the VM, so a fixed sleep is
|
||||||
|
# racy.
|
||||||
|
for _ in $(seq 1 30); do
|
||||||
|
[ -z "$(jobs -r)" ] && break
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
if [ "$(jobs -r)" ]
|
if [ "$(jobs -r)" ]
|
||||||
then
|
then
|
||||||
|
|||||||
76
.github/workflows/smoke/smoke.sh
vendored
76
.github/workflows/smoke/smoke.sh
vendored
@@ -6,6 +6,8 @@ set -o pipefail
|
|||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
NETWORK="nebula-smoke"
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
echo
|
echo
|
||||||
echo " *** cleanup"
|
echo " *** cleanup"
|
||||||
@@ -16,38 +18,71 @@ cleanup() {
|
|||||||
then
|
then
|
||||||
docker kill lighthouse1 host2 host3 host4
|
docker kill lighthouse1 host2 host3 host4
|
||||||
fi
|
fi
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1
|
||||||
}
|
}
|
||||||
|
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
# Create a dedicated smoke network with an explicit subnet (required for --ip
|
||||||
|
# below). Probe a short list of candidates so a locally-used range doesn't
|
||||||
|
# fail the whole test — we only need one to be free.
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1 || true
|
||||||
|
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
|
||||||
|
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
|
||||||
|
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build.sh.
|
||||||
|
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
|
||||||
|
PREFIX="${SUBNET%/*}"
|
||||||
|
PREFIX="${PREFIX%.*}"
|
||||||
|
LIGHTHOUSE_IP="$PREFIX.2"
|
||||||
|
HOST2_IP="$PREFIX.3"
|
||||||
|
HOST3_IP="$PREFIX.4"
|
||||||
|
HOST4_IP="$PREFIX.5"
|
||||||
|
|
||||||
|
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
|
||||||
|
# build/lighthouse1.yml has no IPs to rewrite so it's skipped.
|
||||||
|
for f in build/host2.yml build/host3.yml build/host4.yml; do
|
||||||
|
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
|
||||||
|
mv "$f.tmp" "$f"
|
||||||
|
done
|
||||||
|
|
||||||
CONTAINER="nebula:${NAME:-smoke}"
|
CONTAINER="nebula:${NAME:-smoke}"
|
||||||
|
|
||||||
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
||||||
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
|
||||||
docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
|
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" "$CONTAINER" -config host3.yml -test
|
||||||
docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
|
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" "$CONTAINER" -config host4.yml -test
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
# grab tcpdump pcaps for debugging
|
# grab tcpdump pcaps for debugging
|
||||||
docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
|
docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
|
||||||
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
|
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
|
||||||
docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
|
docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
|
||||||
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
|
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
|
||||||
docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
||||||
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
||||||
docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
|
docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
|
||||||
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
|
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
|
||||||
|
|
||||||
docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
||||||
docker exec host3 ncat -nklv 0.0.0.0 2000 &
|
docker exec host3 ncat -nklv 0.0.0.0 2000 &
|
||||||
|
docker exec host4 ncat -e '/usr/bin/echo helloagainfromhost4' -nkluv 0.0.0.0 4000 &
|
||||||
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
||||||
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
|
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
|
||||||
|
|
||||||
@@ -119,17 +154,24 @@ echo
|
|||||||
echo " *** Testing conntrack"
|
echo " *** Testing conntrack"
|
||||||
echo
|
echo
|
||||||
set -x
|
set -x
|
||||||
# host2 can ping host3 now that host3 pinged it first
|
|
||||||
docker exec host2 ping -c1 192.168.100.3
|
# host4's outbound firewall only allows ICMP to the lighthouse, so host4
|
||||||
# host4 can ping host2 once conntrack established
|
# cannot initiate UDP to host2. Once host2 initiates a flow to host4:4000,
|
||||||
docker exec host2 ping -c1 192.168.100.4
|
# conntrack must let host4's listener reply on that flow. If it doesn't,
|
||||||
docker exec host4 ping -c1 192.168.100.2
|
# the echo back from host4 never reaches host2.
|
||||||
|
docker exec host2 sh -c "(/usr/bin/echo host2; sleep 2) | ncat -nuv 192.168.100.4 4000" | grep -q helloagainfromhost4
|
||||||
|
|
||||||
docker exec host4 sh -c 'kill 1'
|
docker exec host4 sh -c 'kill 1'
|
||||||
docker exec host3 sh -c 'kill 1'
|
docker exec host3 sh -c 'kill 1'
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 5
|
|
||||||
|
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
|
||||||
|
# fixed sleep.
|
||||||
|
for _ in $(seq 1 30); do
|
||||||
|
[ -z "$(jobs -r)" ] && break
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
if [ "$(jobs -r)" ]
|
if [ "$(jobs -r)" ]
|
||||||
then
|
then
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# -*- mode: ruby -*-
|
# -*- mode: ruby -*-
|
||||||
# vi: set ft=ruby :
|
# vi: set ft=ruby :
|
||||||
Vagrant.configure("2") do |config|
|
Vagrant.configure("2") do |config|
|
||||||
config.vm.box = "ubuntu/jammy64"
|
config.vm.box = "bento/ubuntu-24.04"
|
||||||
|
|
||||||
config.vm.synced_folder "../build", "/nebula"
|
config.vm.synced_folder "../build", "/nebula"
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# -*- mode: ruby -*-
|
# -*- mode: ruby -*-
|
||||||
# vi: set ft=ruby :
|
# vi: set ft=ruby :
|
||||||
Vagrant.configure("2") do |config|
|
Vagrant.configure("2") do |config|
|
||||||
config.vm.box = "generic/openbsd7"
|
config.vm.box = "DefinedNet/openbsd78"
|
||||||
|
|
||||||
config.vm.synced_folder "../build", "/nebula", type: "rsync"
|
config.vm.synced_folder "../build", "/nebula", type: "rsync"
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -2,7 +2,21 @@ version: "2"
|
|||||||
linters:
|
linters:
|
||||||
default: none
|
default: none
|
||||||
enable:
|
enable:
|
||||||
|
- sloglint
|
||||||
- testifylint
|
- testifylint
|
||||||
|
settings:
|
||||||
|
sloglint:
|
||||||
|
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
|
||||||
|
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
|
||||||
|
# custom levels instead of LogAttrs when you can.
|
||||||
|
#
|
||||||
|
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
|
||||||
|
# the few legitimate sites (where attrs is built up as a []slog.Attr)
|
||||||
|
# carry a //nolint:sloglint with rationale.
|
||||||
|
kv-only: true
|
||||||
|
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
|
||||||
|
# discard-handler is on by default (since Go 1.24): suggests
|
||||||
|
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
|
||||||
exclusions:
|
exclusions:
|
||||||
generated: lax
|
generated: lax
|
||||||
presets:
|
presets:
|
||||||
|
|||||||
28
CHANGELOG.md
28
CHANGELOG.md
@@ -7,6 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [1.10.3] - 2026-02-06
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations.
|
||||||
|
Both fingerprint representations will be tested against the blocklist.
|
||||||
|
Any newly issued P256 based certificates will have their signature clamped to the low-s form.
|
||||||
|
Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588)
|
||||||
|
|
||||||
|
## [1.10.2] - 2026-01-21
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fix panic when using `use_system_route_table` that was introduced in v1.10.1. (#1580)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Fix some typos in comments. (#1582)
|
||||||
|
- Dependency updates. (#1581)
|
||||||
|
|
||||||
## [1.10.1] - 2026-01-16
|
## [1.10.1] - 2026-01-16
|
||||||
|
|
||||||
See the [v1.10.1](https://github.com/slackhq/nebula/milestone/26?closed=1) milestone for a complete list of changes.
|
See the [v1.10.1](https://github.com/slackhq/nebula/milestone/26?closed=1) milestone for a complete list of changes.
|
||||||
@@ -764,7 +788,9 @@ created.)
|
|||||||
|
|
||||||
- Initial public release.
|
- Initial public release.
|
||||||
|
|
||||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.1...HEAD
|
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD
|
||||||
|
[1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3
|
||||||
|
[1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2
|
||||||
[1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1
|
[1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1
|
||||||
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
|
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
|
||||||
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
|
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
|
||||||
|
|||||||
1
CODEOWNERS
Normal file
1
CODEOWNERS
Normal file
@@ -0,0 +1 @@
|
|||||||
|
#ECCN:Open Source
|
||||||
@@ -57,7 +57,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
|||||||
docker pull nebulaoss/nebula
|
docker pull nebulaoss/nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Mobile
|
#### Mobile ([source code](https://github.com/DefinedNet/mobile_nebula))
|
||||||
|
|
||||||
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200)
|
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200)
|
||||||
- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
|
- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
|
||||||
@@ -76,6 +76,8 @@ Nebula was created to provide a mechanism for groups of hosts to communicate sec
|
|||||||
|
|
||||||
## Getting started (quickly)
|
## Getting started (quickly)
|
||||||
|
|
||||||
|
**Don't want to manage your own PKI and lighthouses?** [Managed Nebula](https://www.defined.net/) from Defined Networking handles all of this for you.
|
||||||
|
|
||||||
To set up a Nebula network, you'll need:
|
To set up a Nebula network, you'll need:
|
||||||
|
|
||||||
#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use.
|
#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use.
|
||||||
|
|||||||
239
bits.go
239
bits.go
@@ -1,23 +1,43 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
mathbits "math/bits"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const bitsPerWord = 64
|
||||||
|
|
||||||
|
// Bits is a sliding-window anti-replay tracker. The window is stored as a
|
||||||
|
// circular bitmap packed into uint64 words (8x denser than a []bool), so a
|
||||||
|
// length-N window costs N/8 bytes. length must be a power of two.
|
||||||
type Bits struct {
|
type Bits struct {
|
||||||
length uint64
|
length uint64
|
||||||
|
lengthMask uint64
|
||||||
current uint64
|
current uint64
|
||||||
bits []bool
|
bits []uint64
|
||||||
lostCounter metrics.Counter
|
lostCounter metrics.Counter
|
||||||
dupeCounter metrics.Counter
|
dupeCounter metrics.Counter
|
||||||
outOfWindowCounter metrics.Counter
|
outOfWindowCounter metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBits(bits uint64) *Bits {
|
func NewBits(length uint64) *Bits {
|
||||||
|
if length == 0 || length&(length-1) != 0 {
|
||||||
|
panic(fmt.Sprintf("Bits length must be a power of two, got %d", length))
|
||||||
|
}
|
||||||
|
|
||||||
|
nWords := length / bitsPerWord
|
||||||
|
if nWords == 0 {
|
||||||
|
nWords = 1
|
||||||
|
}
|
||||||
b := &Bits{
|
b := &Bits{
|
||||||
length: bits,
|
length: length,
|
||||||
bits: make([]bool, bits, bits),
|
lengthMask: length - 1,
|
||||||
|
bits: make([]uint64, nWords),
|
||||||
current: 0,
|
current: 0,
|
||||||
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
|
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
|
||||||
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
||||||
@@ -25,88 +45,219 @@ func NewBits(bits uint64) *Bits {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// There is no counter value 0, mark it to avoid counting a lost packet later.
|
// There is no counter value 0, mark it to avoid counting a lost packet later.
|
||||||
b.bits[0] = true
|
b.bits[0] = 1
|
||||||
b.current = 0
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) get(i uint64) bool {
|
||||||
|
pos := i & b.lengthMask
|
||||||
|
//bit-shifting by 6 because i is a bit index, not a u64 index, and we need to find the u64 without bit in it
|
||||||
|
return b.bits[pos>>6]&(uint64(1)<<(pos&63)) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bits) set(i uint64) {
|
||||||
|
pos := i & b.lengthMask
|
||||||
|
b.bits[pos>>6] |= uint64(1) << (pos & 63)
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearRange clears `count` bits starting at circular position `startPos`
|
||||||
|
// (already masked to [0, length)) and returns how many of them were set
|
||||||
|
// before the clear. count must be in [1, length].
|
||||||
|
func (b *Bits) clearRange(startPos, count uint64) uint64 {
|
||||||
|
wasSet := uint64(0)
|
||||||
|
if count >= b.length {
|
||||||
|
for _, w := range b.bits {
|
||||||
|
wasSet += uint64(mathbits.OnesCount64(w))
|
||||||
|
}
|
||||||
|
clear(b.bits)
|
||||||
|
return wasSet
|
||||||
|
}
|
||||||
|
|
||||||
|
pos := startPos
|
||||||
|
remaining := count
|
||||||
|
|
||||||
|
// handle the potential partial word before pos becomes u64 aligned
|
||||||
|
word := pos >> 6
|
||||||
|
bit := pos & 63
|
||||||
|
take := uint64(64) - bit
|
||||||
|
if take > remaining {
|
||||||
|
take = remaining
|
||||||
|
}
|
||||||
|
if take > b.length-pos {
|
||||||
|
take = b.length - pos
|
||||||
|
}
|
||||||
|
var mask uint64
|
||||||
|
if take == 64 {
|
||||||
|
mask = math.MaxUint64
|
||||||
|
} else {
|
||||||
|
mask = ((uint64(1) << take) - 1) << bit
|
||||||
|
}
|
||||||
|
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
|
||||||
|
b.bits[word] &^= mask
|
||||||
|
remaining -= take
|
||||||
|
pos = (pos + take) & b.lengthMask
|
||||||
|
|
||||||
|
// Clear whole words, keeping track of the number of set bits
|
||||||
|
for remaining >= 64 {
|
||||||
|
word = pos >> 6
|
||||||
|
wasSet += uint64(mathbits.OnesCount64(b.bits[word]))
|
||||||
|
b.bits[word] = 0
|
||||||
|
remaining -= 64
|
||||||
|
pos = (pos + 64) & b.lengthMask
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the remaining partial word
|
||||||
|
if remaining > 0 {
|
||||||
|
word = pos >> 6
|
||||||
|
mask = (uint64(1) << remaining) - 1
|
||||||
|
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
|
||||||
|
b.bits[word] &^= mask
|
||||||
|
}
|
||||||
|
|
||||||
|
return wasSet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bits) strictlyWithinWindow(i uint64) bool {
|
||||||
|
// Handle the case where the window hasn't slid yet. This avoids u64 underflow.
|
||||||
|
inWarmup := b.current < b.length
|
||||||
|
if i < b.length && inWarmup {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, if the packet is in-window, see if we've seen it before
|
||||||
|
if i > b.current-b.length {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false //not within window!
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check returns true if i is within (or way out in front of) the window, and not a replay
|
||||||
|
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true.
|
// If i is the next number, return true.
|
||||||
if i > b.current {
|
if i > b.current {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If i is within the window, check if it's been set already.
|
if b.strictlyWithinWindow(i) {
|
||||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
return !b.get(i)
|
||||||
return !b.bits[i%b.length]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
l.Debug("rejected a packet (top)", "current", b.current, "incoming", i)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
// Update has three branches:
|
||||||
// If i is the next number, return true and update current.
|
// - i == b.current+1: fast path; advance the cursor by one and lose-count
|
||||||
|
// the slot we just stomped (only past warmup; see the i > b.length guard
|
||||||
|
// below).
|
||||||
|
// - i > b.current+1: jump path; clear all slots between current and i
|
||||||
|
// (or up to a full window's worth, whichever is smaller) via clearRange,
|
||||||
|
// then mark i. Two arms here: a warmup arm that handles the very first
|
||||||
|
// window before the cursor has slid, and a steady-state arm that treats
|
||||||
|
// every cleared empty slot as a lost packet.
|
||||||
|
// - i <= b.current: in-window check for duplicates; out-of-window otherwise.
|
||||||
|
//
|
||||||
|
// NewBits seeds bits[0]=1 so counter 0 looks "received" — Update never
|
||||||
|
// clears that marker during warmup (clearRange skips position 0 when
|
||||||
|
// startPos=1), and once b.current >= b.length the marker is no longer
|
||||||
|
// consulted. The marker prevents a fictitious "lost" hit on the first real
|
||||||
|
// counter.
|
||||||
|
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
|
||||||
|
// Fast path: i is the next expected counter. Split out so the function
|
||||||
|
// stays small and avoids paying for the slow paths' slog argument-build
|
||||||
|
// stack frame on every call. The bit read/test/write is inlined to
|
||||||
|
// touch the backing word once.
|
||||||
if i == b.current+1 {
|
if i == b.current+1 {
|
||||||
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
pos := i & b.lengthMask
|
||||||
// The very first window can only be tracked as lost once we are on the 2nd window or greater
|
word := pos >> 6
|
||||||
if b.bits[i%b.length] == false && i > b.length {
|
mask := uint64(1) << (pos & 63)
|
||||||
|
w := b.bits[word]
|
||||||
|
if i > b.length && w&mask == 0 {
|
||||||
b.lostCounter.Inc(1)
|
b.lostCounter.Inc(1)
|
||||||
}
|
}
|
||||||
b.bits[i%b.length] = true
|
b.bits[word] = w | mask
|
||||||
b.current = i
|
b.current = i
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
return b.updateSlow(l, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateSlow handles jumps, in-window backfill, dupes, and out-of-window.
|
||||||
|
func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool {
|
||||||
// If i is a jump, adjust the window, record lost, update current, and return true
|
// If i is a jump, adjust the window, record lost, update current, and return true
|
||||||
if i > b.current {
|
if i > b.current {
|
||||||
lost := int64(0)
|
end := i
|
||||||
// Zero out the bits between the current and the new counter value, limited by the window size,
|
if end > b.current+b.length {
|
||||||
// since the window is shifting
|
end = b.current + b.length
|
||||||
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
|
}
|
||||||
if b.bits[n%b.length] == false && n > b.length {
|
count := end - b.current
|
||||||
lost++
|
startPos := (b.current + 1) & b.lengthMask
|
||||||
|
|
||||||
|
var lost int64
|
||||||
|
if b.current >= b.length {
|
||||||
|
// Steady state: every cleared slot is past warmup, so any unset
|
||||||
|
// bit we evict is a lost packet from the previous cycle.
|
||||||
|
wasSet := b.clearRange(startPos, count)
|
||||||
|
lost = int64(count) - int64(wasSet)
|
||||||
|
} else {
|
||||||
|
// Warmup (the very first window). Some cleared slots represent
|
||||||
|
// packets <= length where eviction is not "lost" in the usual
|
||||||
|
// sense. This branch is taken at most once per connection so we
|
||||||
|
// don't bother optimizing it.
|
||||||
|
for n := b.current + 1; n <= end; n++ {
|
||||||
|
if !b.get(n) && n > b.length {
|
||||||
|
lost++
|
||||||
|
}
|
||||||
}
|
}
|
||||||
b.bits[n%b.length] = false
|
b.clearRange(startPos, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only record any skipped packets as a result of the window moving further than the window length
|
// Anything past the new window can never be backfilled, so it's lost.
|
||||||
// Any loss within the new window will be accounted for in future calls
|
if i > b.current+b.length {
|
||||||
lost += max(0, int64(i-b.current-b.length))
|
lost += int64(i - b.current - b.length)
|
||||||
|
}
|
||||||
b.lostCounter.Inc(lost)
|
b.lostCounter.Inc(lost)
|
||||||
|
|
||||||
b.bits[i%b.length] = true
|
b.set(i)
|
||||||
b.current = i
|
b.current = i
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If i is within the current window but below the current counter,
|
// If i is within the current window but below the current counter, check to see if it's a duplicate
|
||||||
// Check to see if it's a duplicate
|
if b.strictlyWithinWindow(i) {
|
||||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
pos := i & b.lengthMask
|
||||||
if b.current == i || b.bits[i%b.length] == true {
|
word := pos >> 6
|
||||||
if l.Level >= logrus.DebugLevel {
|
mask := uint64(1) << (pos & 63)
|
||||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
w := b.bits[word]
|
||||||
Debug("Receive window")
|
if b.current == i || w&mask != 0 {
|
||||||
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
l.Debug("Receive window",
|
||||||
|
"accepted", false,
|
||||||
|
"currentCounter", b.current,
|
||||||
|
"incomingCounter", i,
|
||||||
|
"reason", "duplicate",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
b.dupeCounter.Inc(1)
|
b.dupeCounter.Inc(1)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
b.bits[i%b.length] = true
|
b.bits[word] = w | mask
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// In all other cases, fail and don't change current.
|
// In all other cases, fail and don't change current.
|
||||||
b.outOfWindowCounter.Inc(1)
|
b.outOfWindowCounter.Inc(1)
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("accepted", false).
|
l.Debug("Receive window",
|
||||||
WithField("currentCounter", b.current).
|
"accepted", false,
|
||||||
WithField("incomingCounter", i).
|
"currentCounter", b.current,
|
||||||
WithField("reason", "nonsense").
|
"incomingCounter", i,
|
||||||
Debug("Receive window")
|
"reason", "nonsense",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
407
bits_test.go
407
bits_test.go
@@ -7,61 +7,79 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// snapshot returns the bitmap as a []bool of length b.length, for readable
|
||||||
|
// test assertions against the now-packed []uint64 storage.
|
||||||
|
func (b *Bits) snapshot() []bool {
|
||||||
|
out := make([]bool, b.length)
|
||||||
|
for i := uint64(0); i < b.length; i++ {
|
||||||
|
out[i] = b.get(i)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBitsRequiresPowerOfTwo(t *testing.T) {
|
||||||
|
assert.Panics(t, func() { NewBits(10) })
|
||||||
|
assert.Panics(t, func() { NewBits(0) })
|
||||||
|
assert.NotPanics(t, func() { NewBits(1) })
|
||||||
|
assert.NotPanics(t, func() { NewBits(16) })
|
||||||
|
assert.NotPanics(t, func() { NewBits(1024) })
|
||||||
|
assert.NotPanics(t, func() { NewBits(16384) })
|
||||||
|
}
|
||||||
|
|
||||||
func TestBits(t *testing.T) {
|
func TestBits(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(16)
|
||||||
|
assert.EqualValues(t, 16, b.length)
|
||||||
// make sure it is the right size
|
|
||||||
assert.Len(t, b.bits, 10)
|
|
||||||
|
|
||||||
// This is initialized to zero - receive one. This should work.
|
// This is initialized to zero - receive one. This should work.
|
||||||
assert.True(t, b.Check(l, 1))
|
assert.True(t, b.Check(l, 1))
|
||||||
assert.True(t, b.Update(l, 1))
|
assert.True(t, b.Update(l, 1))
|
||||||
assert.EqualValues(t, 1, b.current)
|
assert.EqualValues(t, 1, b.current)
|
||||||
g := []bool{true, true, false, false, false, false, false, false, false, false}
|
g := []bool{true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.snapshot())
|
||||||
|
|
||||||
// Receive two
|
// Receive two
|
||||||
assert.True(t, b.Check(l, 2))
|
assert.True(t, b.Check(l, 2))
|
||||||
assert.True(t, b.Update(l, 2))
|
assert.True(t, b.Update(l, 2))
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
g = []bool{true, true, true, false, false, false, false, false, false, false}
|
g = []bool{true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.snapshot())
|
||||||
|
|
||||||
// Receive two again - it will fail
|
// Receive two again - it will fail
|
||||||
assert.False(t, b.Check(l, 2))
|
assert.False(t, b.Check(l, 2))
|
||||||
assert.False(t, b.Update(l, 2))
|
assert.False(t, b.Update(l, 2))
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
|
|
||||||
// Jump ahead to 15, which should clear everything and set the 6th element
|
// Jump ahead to 25, which clears the window and sets slot 25%16 = 9.
|
||||||
assert.True(t, b.Check(l, 15))
|
assert.True(t, b.Check(l, 25))
|
||||||
assert.True(t, b.Update(l, 15))
|
assert.True(t, b.Update(l, 25))
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 25, b.current)
|
||||||
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
g = []bool{false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.snapshot())
|
||||||
|
|
||||||
// Mark 14, which is allowed because it is in the window
|
// Mark 24, which is in window (current 25, length 16, window covers [10,25]).
|
||||||
assert.True(t, b.Check(l, 14))
|
assert.True(t, b.Check(l, 24))
|
||||||
assert.True(t, b.Update(l, 14))
|
assert.True(t, b.Update(l, 24))
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 25, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.snapshot())
|
||||||
|
|
||||||
// Mark 5, which is not allowed because it is not in the window
|
// Mark 5, not allowed because 5 <= current-length (25-16=9).
|
||||||
assert.False(t, b.Check(l, 5))
|
assert.False(t, b.Check(l, 5))
|
||||||
assert.False(t, b.Update(l, 5))
|
assert.False(t, b.Update(l, 5))
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 25, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.snapshot())
|
||||||
|
|
||||||
// make sure we handle wrapping around once to the current position
|
// Make sure we handle wrapping around once to the same slot. With
|
||||||
b = NewBits(10)
|
// length=16, packets 1 and 17 share slot 1.
|
||||||
|
b = NewBits(16)
|
||||||
assert.True(t, b.Update(l, 1))
|
assert.True(t, b.Update(l, 1))
|
||||||
assert.True(t, b.Update(l, 11))
|
assert.True(t, b.Update(l, 17))
|
||||||
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
|
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}, b.snapshot())
|
||||||
|
|
||||||
// Walk through a few windows in order
|
// Walk through a few windows in order
|
||||||
b = NewBits(10)
|
b = NewBits(16)
|
||||||
for i := uint64(1); i <= 100; i++ {
|
for i := uint64(1); i <= 100; i++ {
|
||||||
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
||||||
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
||||||
@@ -72,24 +90,31 @@ func TestBits(t *testing.T) {
|
|||||||
|
|
||||||
func TestBitsLargeJumps(t *testing.T) {
|
func TestBitsLargeJumps(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
|
||||||
|
// length=16. Update(55) from current=0:
|
||||||
|
// warmup, per-bit loop sees no n>16 with unset bits (slot 0 was set by
|
||||||
|
// NewBits and gets re-evaluated when n=16; n=16 is not strictly > 16),
|
||||||
|
// so the loop contributes 0. The jump exceeds the window so we record
|
||||||
|
// 55 - 0 - 16 = 39 packets fell out the back.
|
||||||
|
b := NewBits(16)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
|
assert.True(t, b.Update(l, 55))
|
||||||
|
assert.Equal(t, int64(39), b.lostCounter.Count())
|
||||||
|
|
||||||
b = NewBits(10)
|
// Update(100): clears 16 slots starting at slot 56%16=8. Only slot 7 (for
|
||||||
b.lostCounter.Clear()
|
// packet 55) was set, so 16 - 1 = 15 evicted slots had unset bits.
|
||||||
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
|
// Plus 100 - 55 - 16 = 29 packets fell past the window. Total 44.
|
||||||
assert.Equal(t, int64(45), b.lostCounter.Count())
|
assert.True(t, b.Update(l, 100))
|
||||||
|
assert.Equal(t, int64(39+44), b.lostCounter.Count())
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
|
// Update(200): same shape: 16 - 1 = 15 evicted unset, plus 200 - 100 - 16 = 84 past window. Total 99.
|
||||||
assert.Equal(t, int64(89), b.lostCounter.Count())
|
assert.True(t, b.Update(l, 200))
|
||||||
|
assert.Equal(t, int64(39+44+99), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
|
|
||||||
assert.Equal(t, int64(188), b.lostCounter.Count())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsDupeCounter(t *testing.T) {
|
func TestBitsDupeCounter(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(16)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
@@ -114,120 +139,117 @@ func TestBitsDupeCounter(t *testing.T) {
|
|||||||
|
|
||||||
func TestBitsOutOfWindowCounter(t *testing.T) {
|
func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(16)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
|
// Jump to 20 (warmup branch + 4 past-window packets).
|
||||||
assert.True(t, b.Update(l, 20))
|
assert.True(t, b.Update(l, 20))
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 21))
|
// 9 single-step advances, each evicts a slot whose bit was cleared during
|
||||||
assert.True(t, b.Update(l, 22))
|
// the jump above and whose value was never seen, so each contributes 1
|
||||||
assert.True(t, b.Update(l, 23))
|
// to lostCounter.
|
||||||
assert.True(t, b.Update(l, 24))
|
for n := uint64(21); n <= 29; n++ {
|
||||||
assert.True(t, b.Update(l, 25))
|
assert.True(t, b.Update(l, n))
|
||||||
assert.True(t, b.Update(l, 26))
|
}
|
||||||
assert.True(t, b.Update(l, 27))
|
|
||||||
assert.True(t, b.Update(l, 28))
|
|
||||||
assert.True(t, b.Update(l, 29))
|
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
|
// 0 is below current-length (29-16=13) so it falls outside the window.
|
||||||
assert.False(t, b.Update(l, 0))
|
assert.False(t, b.Update(l, 0))
|
||||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
// 4 from the Update(20) jump + 9 from 21..29.
|
||||||
|
assert.Equal(t, int64(13), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsLostCounter(t *testing.T) {
|
func TestBitsLostCounter(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(16)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 20))
|
// Walk 20..29 like the original, just with a bigger window. Same
|
||||||
assert.True(t, b.Update(l, 21))
|
// reasoning as TestBitsOutOfWindowCounter: 4 past-window from Update(20),
|
||||||
assert.True(t, b.Update(l, 22))
|
// then 9 more from the unit advances.
|
||||||
assert.True(t, b.Update(l, 23))
|
for n := uint64(20); n <= 29; n++ {
|
||||||
assert.True(t, b.Update(l, 24))
|
assert.True(t, b.Update(l, n))
|
||||||
assert.True(t, b.Update(l, 25))
|
}
|
||||||
assert.True(t, b.Update(l, 26))
|
assert.Equal(t, int64(13), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 27))
|
|
||||||
assert.True(t, b.Update(l, 28))
|
|
||||||
assert.True(t, b.Update(l, 29))
|
|
||||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
b = NewBits(10)
|
b = NewBits(16)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 9))
|
// Update(15) clears the warmup window (no lost), sets slot 15.
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
// 10 will set 0 index, 0 was already set, no lost packets
|
|
||||||
assert.True(t, b.Update(l, 10))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
|
|
||||||
assert.True(t, b.Update(l, 11))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
// Now let's fill in the window, should end up with 8 lost packets
|
|
||||||
assert.True(t, b.Update(l, 12))
|
|
||||||
assert.True(t, b.Update(l, 13))
|
|
||||||
assert.True(t, b.Update(l, 14))
|
|
||||||
assert.True(t, b.Update(l, 15))
|
assert.True(t, b.Update(l, 15))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
|
||||||
|
// Update(16): slot 0 was already set (NewBits seeded it), and 16 is not
|
||||||
|
// strictly > length, so nothing is recorded as lost.
|
||||||
assert.True(t, b.Update(l, 16))
|
assert.True(t, b.Update(l, 16))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
|
||||||
|
// Update(17): we jumped straight from 0 to 15, so slot 1 was cleared
|
||||||
|
// (and never re-set). 17 > 16 is past warmup, so packet 1 is recorded lost.
|
||||||
assert.True(t, b.Update(l, 17))
|
assert.True(t, b.Update(l, 17))
|
||||||
assert.True(t, b.Update(l, 18))
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 19))
|
|
||||||
assert.Equal(t, int64(8), b.lostCounter.Count())
|
|
||||||
|
|
||||||
// Jump ahead by a window size
|
// Fill in 18..30 in single steps. Each i evicts slot i%16. Slots 2..14
|
||||||
assert.True(t, b.Update(l, 29))
|
// were all cleared during Update(15), and we never re-set any of them,
|
||||||
assert.Equal(t, int64(8), b.lostCounter.Count())
|
// so each i in 18..30 is a fresh lost packet — 13 more.
|
||||||
// Now lets walk ahead normally through the window, the missed packets should fill in
|
for n := uint64(18); n <= 30; n++ {
|
||||||
assert.True(t, b.Update(l, 30))
|
assert.True(t, b.Update(l, n))
|
||||||
assert.True(t, b.Update(l, 31))
|
}
|
||||||
assert.True(t, b.Update(l, 32))
|
assert.Equal(t, int64(14), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 33))
|
|
||||||
assert.True(t, b.Update(l, 34))
|
|
||||||
assert.True(t, b.Update(l, 35))
|
|
||||||
assert.True(t, b.Update(l, 36))
|
|
||||||
assert.True(t, b.Update(l, 37))
|
|
||||||
assert.True(t, b.Update(l, 38))
|
|
||||||
// 39 packets tracked, 22 seen, 17 lost
|
|
||||||
assert.Equal(t, int64(17), b.lostCounter.Count())
|
|
||||||
|
|
||||||
// Jump ahead by 2 windows, should have recording 1 full window missing
|
// Jump ahead by exactly one window size.
|
||||||
assert.True(t, b.Update(l, 58))
|
assert.True(t, b.Update(l, 46))
|
||||||
assert.Equal(t, int64(27), b.lostCounter.Count())
|
// end = min(46, 30+16) = 46, count = 16, all slots cleared. Before the
|
||||||
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
|
// jump every slot 0..15 had been set (Update(15), (16), (17), 18..30),
|
||||||
assert.True(t, b.Update(l, 59))
|
// so wasSet=16 and 46 == current+length means no past-window slack:
|
||||||
assert.True(t, b.Update(l, 60))
|
// lost contribution = 0.
|
||||||
assert.True(t, b.Update(l, 61))
|
assert.Equal(t, int64(14), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 62))
|
|
||||||
assert.True(t, b.Update(l, 63))
|
// Walk 47..55. The Update(46) jump cleared every slot, so only slot 14
|
||||||
assert.True(t, b.Update(l, 64))
|
// (for packet 46) is set when we start. Each subsequent unit step lands
|
||||||
assert.True(t, b.Update(l, 65))
|
// on a slot that was cleared and is past warmup, so it counts as lost.
|
||||||
assert.True(t, b.Update(l, 66))
|
// 9 more = 23.
|
||||||
assert.True(t, b.Update(l, 67))
|
for n := uint64(47); n <= 55; n++ {
|
||||||
// 68 packets tracked, 32 seen, 36 missed
|
assert.True(t, b.Update(l, n))
|
||||||
assert.Equal(t, int64(36), b.lostCounter.Count())
|
}
|
||||||
|
assert.Equal(t, int64(23), b.lostCounter.Count())
|
||||||
|
|
||||||
|
// Jump ahead by two windows: clears the window plus past-window loss.
|
||||||
|
assert.True(t, b.Update(l, 87))
|
||||||
|
// current=55, length=16. end = min(87, 71) = 71. count=16, all slots
|
||||||
|
// cleared. Slots set before the clear are slots 14,15,0..7 (10 total).
|
||||||
|
// Lost from clear = 16 - 10 = 6. Past window: 87 - 55 - 16 = 16. +22.
|
||||||
|
assert.Equal(t, int64(45), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsLostCounterIssue1(t *testing.T) {
|
func TestBitsLostCounterIssue1(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(16)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
|
// Receive 4, backfill 1, then 9, 2, 3, 5, 6, 7 (skip 8), 10, 11, 14.
|
||||||
|
// Then jump to 25 — slot 25%16=9 is being evicted, but it had been set
|
||||||
|
// (we received packet 9), so no spurious lost increment. The original
|
||||||
|
// regression was about double-counting a missing packet when its slot
|
||||||
|
// got cleared on a jump. With the jump path now using clearRange's
|
||||||
|
// word-level wasSet count, the same semantics hold.
|
||||||
assert.True(t, b.Update(l, 4))
|
assert.True(t, b.Update(l, 4))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 1))
|
assert.True(t, b.Update(l, 1))
|
||||||
@@ -244,7 +266,7 @@ func TestBitsLostCounterIssue1(t *testing.T) {
|
|||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 7))
|
assert.True(t, b.Update(l, 7))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
// assert.True(t, b.Update(l, 8))
|
// Skip packet 8.
|
||||||
assert.True(t, b.Update(l, 10))
|
assert.True(t, b.Update(l, 10))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 11))
|
assert.True(t, b.Update(l, 11))
|
||||||
@@ -252,9 +274,23 @@ func TestBitsLostCounterIssue1(t *testing.T) {
|
|||||||
|
|
||||||
assert.True(t, b.Update(l, 14))
|
assert.True(t, b.Update(l, 14))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
|
|
||||||
assert.True(t, b.Update(l, 19))
|
// Jump to 25. With length=16, slot 25%16=9 corresponds to packet 9
|
||||||
|
// (which we DID receive), so its bit is set and no lost++ from that
|
||||||
|
// eviction. The trace below shows the only loss is packet 8.
|
||||||
|
assert.True(t, b.Update(l, 25))
|
||||||
|
// current was 14, i=25. end=min(25,30)=25. count=11. startPos=15.
|
||||||
|
// steady? current=14<16, so warmup branch: per-bit n=15..25, count those
|
||||||
|
// with !get(n) AND n>16. n=17..25 are >16. Among slots 17%16=1..25%16=9
|
||||||
|
// did we set slots 1..9 (packets 1..9)? Yes for all but slot 8 (packet 8
|
||||||
|
// was skipped). n=24 maps to slot 8 which is FALSE → lost++. All other
|
||||||
|
// n in 17..25 map to slots that are set. n=16 is not strictly > 16. So
|
||||||
|
// lost = 1.
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
|
||||||
|
// Fill in 12, 13, 15, 16. Each is below current=25 (in-window). 16 must
|
||||||
|
// recheck slot 0 — it was set by NewBits and then cleared by the
|
||||||
|
// Update(25) jump, so 16 backfills cleanly.
|
||||||
assert.True(t, b.Update(l, 12))
|
assert.True(t, b.Update(l, 12))
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 13))
|
assert.True(t, b.Update(l, 13))
|
||||||
@@ -263,29 +299,140 @@ func TestBitsLostCounterIssue1(t *testing.T) {
|
|||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 16))
|
assert.True(t, b.Update(l, 16))
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
assert.True(t, b.Update(l, 17))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 18))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 20))
|
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 21))
|
|
||||||
|
|
||||||
// We missed packet 8 above
|
// We missed packet 8 above and that loss is still recorded once, never
|
||||||
|
// double-counted, never zeroed.
|
||||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkBits(b *testing.B) {
|
// TestBitsWarmupOvershoot exercises the jump path's warmup arm with an
|
||||||
z := NewBits(10)
|
// overshoot past one full window. NewBits leaves current=0 with only slot 0
|
||||||
for n := 0; n < b.N; n++ {
|
// "set" by the marker. Jumping straight to length+k must (a) clear every
|
||||||
for i := range z.bits {
|
// slot the jump straddles, (b) count only past-window slack (not the
|
||||||
z.bits[i] = true
|
// in-window slots, which never had a "lost" tenant during warmup), and
|
||||||
}
|
// (c) leave the cursor at the new counter so subsequent unit advances
|
||||||
for i := range z.bits {
|
// count from steady state. The marker bit at slot 0 is irrelevant once
|
||||||
z.bits[i] = false
|
// current >= length.
|
||||||
}
|
func TestBitsWarmupOvershoot(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
b := NewBits(16)
|
||||||
|
b.lostCounter.Clear()
|
||||||
|
|
||||||
|
// Jump from current=0 to i=20 (length=16, overshoot=4).
|
||||||
|
// Warmup arm: counts slots in [1..16] where bit unset and n>length.
|
||||||
|
// Only n=16 was unset and >length: but slot 16%16=0 is the marker,
|
||||||
|
// so b.get(16) reads bits[0]=1 and skips. Result: 0 lost from the loop.
|
||||||
|
// Past-window: i - current - length = 20 - 0 - 16 = 4 lost.
|
||||||
|
assert.True(t, b.Update(l, 20))
|
||||||
|
assert.Equal(t, int64(4), b.lostCounter.Count())
|
||||||
|
assert.Equal(t, uint64(20), b.current)
|
||||||
|
|
||||||
|
// Steady state now (current=20 >= length=16). Unit advance to 21
|
||||||
|
// stomps slot 21%16=5, which was cleared by the jump and not reset,
|
||||||
|
// so this is +1 lost.
|
||||||
|
assert.True(t, b.Update(l, 21))
|
||||||
|
assert.Equal(t, int64(5), b.lostCounter.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBitsCheckAcrossWarmupBoundary pins the underflow trick in Check's
|
||||||
|
// in-window clause. While in warmup, b.current-b.length underflows uint64
|
||||||
|
// to a huge value so the first OR-clause is always false; the second
|
||||||
|
// clause (i < length && current < length) carries the in-window check.
|
||||||
|
// Once current >= length the regimes flip cleanly.
|
||||||
|
func TestBitsCheckAcrossWarmupBoundary(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
b := NewBits(16)
|
||||||
|
|
||||||
|
// Warmup: current=0. Check(0) must read the marker (set) and return false.
|
||||||
|
assert.False(t, b.Check(l, 0), "marker slot should look already-received")
|
||||||
|
// Warmup: any 0 < i < length is in-window and unset → accepted.
|
||||||
|
for i := uint64(1); i < 16; i++ {
|
||||||
|
assert.True(t, b.Check(l, i), "warmup in-window i=%d should be accepted", i)
|
||||||
|
}
|
||||||
|
// Warmup: i >= length but > current is "next number" so accepted.
|
||||||
|
assert.True(t, b.Check(l, 16))
|
||||||
|
assert.True(t, b.Check(l, 1_000_000))
|
||||||
|
|
||||||
|
// Cross into steady state.
|
||||||
|
assert.True(t, b.Update(l, 100))
|
||||||
|
// Now current=100, length=16. In-window range is [85..100].
|
||||||
|
// 84 is just outside: the underflow clause activates; 84 > 100-16=84 is false.
|
||||||
|
// And the warmup clause is false (current >= length). So out of window.
|
||||||
|
assert.False(t, b.Check(l, 84))
|
||||||
|
// 85 sits at the boundary. 85 > 84 is true → in window, unset → accept.
|
||||||
|
assert.True(t, b.Check(l, 85))
|
||||||
|
// 100 is current itself; not strictly greater, in-window, but already set.
|
||||||
|
assert.False(t, b.Check(l, 100))
|
||||||
|
// Way out: clearly out of window.
|
||||||
|
assert.False(t, b.Check(l, 50))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBitsMarkerInvariant verifies the seeded bits[0]=1 marker behaves
|
||||||
|
// correctly across warmup and beyond. Update should never clear the marker
|
||||||
|
// during warmup (clearRange skips position 0 when startPos=1), and once
|
||||||
|
// current >= length the marker is no longer consulted by Check/Update on
|
||||||
|
// the live path — but it must still report counter 0 as a duplicate while
|
||||||
|
// we are in warmup.
|
||||||
|
func TestBitsMarkerInvariant(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
b := NewBits(8)
|
||||||
|
|
||||||
|
// Counter 0 is the seeded marker; Check sees it as already received.
|
||||||
|
assert.False(t, b.Check(l, 0))
|
||||||
|
// Update(0) at current=0 hits the duplicate branch.
|
||||||
|
b.dupeCounter.Clear()
|
||||||
|
assert.False(t, b.Update(l, 0))
|
||||||
|
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
||||||
|
|
||||||
|
// Walk forward through warmup; the marker must remain set.
|
||||||
|
for n := uint64(1); n <= 7; n++ {
|
||||||
|
assert.True(t, b.Update(l, n))
|
||||||
|
}
|
||||||
|
// Position 0 (the marker) should still read as set because we never
|
||||||
|
// cleared it; Update(0) still looks like a duplicate.
|
||||||
|
assert.False(t, b.Check(l, 0))
|
||||||
|
|
||||||
|
// Cross into steady state with a unit advance to 8: pos=0, evicts the
|
||||||
|
// marker bit. The lost-counter guard (i > b.length) is false (8 == 8),
|
||||||
|
// so this advance does NOT charge a lost packet — exactly what the
|
||||||
|
// marker is there to prevent.
|
||||||
|
b.lostCounter.Clear()
|
||||||
|
assert.True(t, b.Update(l, 8))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
// The slot at pos 0 is now occupied by counter 8.
|
||||||
|
assert.False(t, b.Check(l, 8))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkBitsUpdateInOrder is the steady-state hot path: each call is
|
||||||
|
// i == current+1.
|
||||||
|
func BenchmarkBitsUpdateInOrder(b *testing.B) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
z := NewBits(16384)
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
z.Update(l, uint64(n)+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkBitsUpdateReorder simulates light reorder within the window:
|
||||||
|
// every other packet arrives one slot behind its predecessor (forces the
|
||||||
|
// in-window backfill branch).
|
||||||
|
func BenchmarkBitsUpdateReorder(b *testing.B) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
z := NewBits(16384)
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
base := uint64(n) * 2
|
||||||
|
z.Update(l, base+2)
|
||||||
|
z.Update(l, base+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkBitsUpdateLargeJumps stresses the clearRange word-level path.
|
||||||
|
func BenchmarkBitsUpdateLargeJumps(b *testing.B) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
z := NewBits(16384)
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
z.Update(l, uint64(n+1)*1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build boringcrypto
|
//go:build boringcrypto
|
||||||
// +build boringcrypto
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,22 +32,46 @@ func NewCAPool() *CAPool {
|
|||||||
// If the pool contains any expired certificates, an ErrExpired will be
|
// If the pool contains any expired certificates, an ErrExpired will be
|
||||||
// returned along with the pool. The caller must handle any such errors.
|
// returned along with the pool. The caller must handle any such errors.
|
||||||
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
||||||
|
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
|
||||||
|
// The reader must contain a PEM-encoded set of nebula certificates.
|
||||||
|
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
|
||||||
pool := NewCAPool()
|
pool := NewCAPool()
|
||||||
var err error
|
|
||||||
var expired bool
|
var expired bool
|
||||||
for {
|
|
||||||
caPEMs, err = pool.AddCAFromPEM(caPEMs)
|
scanner := bufio.NewScanner(r)
|
||||||
if errors.Is(err, ErrExpired) {
|
scanner.Split(SplitPEM)
|
||||||
expired = true
|
|
||||||
err = nil
|
for scanner.Scan() {
|
||||||
|
pemBytes := scanner.Bytes()
|
||||||
|
|
||||||
|
block, rest := pem.Decode(pemBytes)
|
||||||
|
if len(bytes.TrimSpace(rest)) > 0 {
|
||||||
|
return nil, ErrInvalidPEMBlock
|
||||||
}
|
}
|
||||||
|
if block == nil {
|
||||||
|
return nil, ErrInvalidPEMBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := unmarshalCertificateBlock(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
|
|
||||||
break
|
err = pool.AddCA(c)
|
||||||
|
if errors.Is(err, ErrExpired) {
|
||||||
|
expired = true
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, ErrInvalidPEMBlock
|
||||||
|
}
|
||||||
|
|
||||||
if expired {
|
if expired {
|
||||||
return pool, ErrExpired
|
return pool, ErrExpired
|
||||||
@@ -141,10 +168,23 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre nebula v1.10.3 could generate signatures in either high or low s form and validation
|
||||||
|
// of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form
|
||||||
|
// but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we
|
||||||
|
// need to test both forms until such a time comes that we enforce low-s form on signature validation.
|
||||||
|
fp2, err := CalculateAlternateFingerprint(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err)
|
||||||
|
}
|
||||||
|
if fp2 != "" && ncp.IsBlocklisted(fp2) {
|
||||||
|
return nil, ErrBlockListed
|
||||||
|
}
|
||||||
|
|
||||||
cc := CachedCertificate{
|
cc := CachedCertificate{
|
||||||
Certificate: c,
|
Certificate: c,
|
||||||
InvertedGroups: make(map[string]struct{}),
|
InvertedGroups: make(map[string]struct{}),
|
||||||
Fingerprint: fp,
|
Fingerprint: fp,
|
||||||
|
fingerprint2: fp2,
|
||||||
signerFingerprint: signer.Fingerprint,
|
signerFingerprint: signer.Fingerprint,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,6 +198,11 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
|
|||||||
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
|
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
|
||||||
// is a cheaper operation to perform as a result.
|
// is a cheaper operation to perform as a result.
|
||||||
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
|
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
|
||||||
|
// Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s
|
||||||
|
if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) {
|
||||||
|
return ErrBlockListed
|
||||||
|
}
|
||||||
|
|
||||||
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
|
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -111,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
|
|||||||
assert.Len(t, ppppp.CAs, 1)
|
assert.Len(t, ppppp.CAs, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oneByteReader wraps a reader to return at most 1 byte per Read call,
|
||||||
|
// exercising the streaming accumulation logic in NewCAPoolFromPEMReader.
|
||||||
|
type oneByteReader struct {
|
||||||
|
r io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *oneByteReader) Read(p []byte) (int, error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return o.r.Read(p[:1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) {
|
||||||
|
pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, pool.CAs)
|
||||||
|
|
||||||
|
pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n "))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, pool.CAs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) {
|
||||||
|
ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
|
ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
|
|
||||||
|
bundle := append(pem1, pem2...)
|
||||||
|
pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, pool.CAs, 2)
|
||||||
|
|
||||||
|
fp1, err := ca1.Fingerprint()
|
||||||
|
require.NoError(t, err)
|
||||||
|
fp2, err := ca2.Fingerprint()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Contains(t, pool.CAs, fp1)
|
||||||
|
assert.Contains(t, pool.CAs, fp2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) {
|
||||||
|
_, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata"))
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) {
|
||||||
|
_, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
|
|
||||||
|
bundle := append(pem1, []byte("some trailing garbage")...)
|
||||||
|
_, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle))
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify(t *testing.T) {
|
func TestCertificateV1_Verify(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
@@ -170,6 +228,15 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
// Create a copy of the cert and swap to the alternate form for the signature
|
||||||
|
nc := c.Copy()
|
||||||
|
b, err := p256.Swap(c.Signature())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, nc.(*certificateV1).setSignature(b))
|
||||||
|
|
||||||
|
_, err = caPool.VerifyCertificate(time.Now(), nc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -187,7 +254,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err = caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
@@ -196,7 +263,17 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reset the blocklist and block the alternate form fingerprint
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
caPool.BlocklistFingerprint(cc.fingerprint2)
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,6 +471,15 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
// Create a copy of the cert and swap to the alternate form for the signature
|
||||||
|
nc := c.Copy()
|
||||||
|
b, err := p256.Swap(c.Signature())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, nc.(*certificateV2).setSignature(b))
|
||||||
|
|
||||||
|
_, err = caPool.VerifyCertificate(time.Now(), nc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -411,7 +497,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err = caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
@@ -420,7 +506,17 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reset the blocklist and block the alternate form fingerprint
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
caPool.BlocklistFingerprint(cc.fingerprint2)
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
33
cert/cert.go
33
cert/cert.go
@@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Version uint8
|
type Version uint8
|
||||||
@@ -110,6 +112,9 @@ type CachedCertificate struct {
|
|||||||
InvertedGroups map[string]struct{}
|
InvertedGroups map[string]struct{}
|
||||||
Fingerprint string
|
Fingerprint string
|
||||||
signerFingerprint string
|
signerFingerprint string
|
||||||
|
|
||||||
|
// A place to store a 2nd fingerprint if the certificate could have one, such as with P256
|
||||||
|
fingerprint2 string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *CachedCertificate) String() string {
|
func (cc *CachedCertificate) String() string {
|
||||||
@@ -152,3 +157,31 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
|
|||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates
|
||||||
|
// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step.
|
||||||
|
func CalculateAlternateFingerprint(c Certificate) (string, error) {
|
||||||
|
if c.Curve() != Curve_P256 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nc := c.Copy()
|
||||||
|
b, err := p256.Swap(nc.Signature())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := nc.(type) {
|
||||||
|
case *certificateV1:
|
||||||
|
err = v.setSignature(b)
|
||||||
|
case *certificateV2:
|
||||||
|
err = v.setSignature(b)
|
||||||
|
default:
|
||||||
|
return "", ErrUnknownVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return nc.Fingerprint()
|
||||||
|
}
|
||||||
|
|||||||
127
cert/p256/p256.go
Normal file
127
cert/p256/p256.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package p256
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/elliptic"
|
||||||
|
"errors"
|
||||||
|
"math/big"
|
||||||
|
|
||||||
|
"filippo.io/bigmod"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/cryptobyte"
|
||||||
|
"golang.org/x/crypto/cryptobyte/asn1"
|
||||||
|
)
|
||||||
|
|
||||||
|
var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1)
|
||||||
|
var nMod *bigmod.Modulus
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
nMod = n
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsNormalized(sig []byte) (bool, error) {
|
||||||
|
r, s, err := parseSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return checkLowS(r, s), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkLowS(_, s []byte) bool {
|
||||||
|
bigS := new(big.Int).SetBytes(s)
|
||||||
|
// Check if S <= (N/2), because we want to include the midpoint in the set of low-s
|
||||||
|
return bigS.Cmp(halfN) <= 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func swap(r, s []byte) ([]byte, []byte, error) {
|
||||||
|
var err error
|
||||||
|
bigS, err := bigmod.NewNat().SetBytes(s, nMod)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
sNormalized := nMod.Nat().Sub(bigS, nMod)
|
||||||
|
|
||||||
|
result := sNormalized.Bytes(nMod)
|
||||||
|
for len(result) > 1 && result[0] == 0 {
|
||||||
|
result = result[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Normalize(sig []byte) ([]byte, error) {
|
||||||
|
r, s, err := parseSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if checkLowS(r, s) {
|
||||||
|
return sig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newR, newS, err := swap(r, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return encodeSignature(newR, newS)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap will change sig between its current form to the opposite high or low form.
|
||||||
|
func Swap(sig []byte) ([]byte, error) {
|
||||||
|
r, s, err := parseSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newR, newS, err := swap(r, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return encodeSignature(newR, newS)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSignature taken exactly from crypto/ecdsa/ecdsa.go
|
||||||
|
func parseSignature(sig []byte) (r, s []byte, err error) {
|
||||||
|
var inner cryptobyte.String
|
||||||
|
input := cryptobyte.String(sig)
|
||||||
|
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
|
||||||
|
!input.Empty() ||
|
||||||
|
!inner.ReadASN1Integer(&r) ||
|
||||||
|
!inner.ReadASN1Integer(&s) ||
|
||||||
|
!inner.Empty() {
|
||||||
|
return nil, nil, errors.New("invalid ASN.1")
|
||||||
|
}
|
||||||
|
return r, s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeSignature(r, s []byte) ([]byte, error) {
|
||||||
|
var b cryptobyte.Builder
|
||||||
|
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||||
|
addASN1IntBytes(b, r)
|
||||||
|
addASN1IntBytes(b, s)
|
||||||
|
})
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// addASN1IntBytes encodes in ASN.1 a positive integer represented as
|
||||||
|
// a big-endian byte slice with zero or more leading zeroes.
|
||||||
|
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
|
||||||
|
for len(bytes) > 0 && bytes[0] == 0 {
|
||||||
|
bytes = bytes[1:]
|
||||||
|
}
|
||||||
|
if len(bytes) == 0 {
|
||||||
|
b.SetError(errors.New("invalid integer"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) {
|
||||||
|
if bytes[0]&0x80 != 0 {
|
||||||
|
c.AddUint8(0)
|
||||||
|
}
|
||||||
|
c.AddBytes(bytes)
|
||||||
|
})
|
||||||
|
}
|
||||||
28
cert/p256/p256_test.go
Normal file
28
cert/p256/p256_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package p256
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlipping(t *testing.T) {
|
||||||
|
priv, err1 := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
|
||||||
|
out, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r, s, err := parseSignature(out)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r, s1, err := swap(r, s)
|
||||||
|
require.NoError(t, err)
|
||||||
|
r, s2, err := swap(r, s1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, s, s2)
|
||||||
|
require.NotEqual(t, s, s1)
|
||||||
|
}
|
||||||
82
cert/pem.go
82
cert/pem.go
@@ -1,12 +1,66 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrTruncatedPEMBlock = errors.New("truncated PEM block")
|
||||||
|
|
||||||
|
// SplitPEM is a split function for bufio.Scanner that returns each PEM block.
|
||||||
|
func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
// Look for the start of a PEM block
|
||||||
|
start := bytes.Index(data, []byte("-----BEGIN "))
|
||||||
|
if start == -1 {
|
||||||
|
if atEOF && len(bytes.TrimSpace(data)) > 0 {
|
||||||
|
// Non-whitespace content with no PEM block
|
||||||
|
return 0, nil, ErrTruncatedPEMBlock
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), nil, nil
|
||||||
|
}
|
||||||
|
// Request more data
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for the end marker
|
||||||
|
endMarkerStart := bytes.Index(data[start:], []byte("-----END "))
|
||||||
|
if endMarkerStart == -1 {
|
||||||
|
if atEOF {
|
||||||
|
// Incomplete PEM block at EOF
|
||||||
|
return 0, nil, ErrTruncatedPEMBlock
|
||||||
|
}
|
||||||
|
// Need more data to find the end
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the actual end of the END line (after the newline)
|
||||||
|
endMarkerStart += start
|
||||||
|
endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n')
|
||||||
|
var end int
|
||||||
|
if endLineEnd == -1 {
|
||||||
|
if atEOF {
|
||||||
|
// END marker without newline at EOF - take it anyway
|
||||||
|
end = len(data)
|
||||||
|
} else {
|
||||||
|
// Need more data
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
end = endMarkerStart + endLineEnd + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the PEM block
|
||||||
|
pemBlock := data[start:end]
|
||||||
|
|
||||||
|
// Return the valid PEM block
|
||||||
|
return end, pemBlock, nil
|
||||||
|
}
|
||||||
|
|
||||||
const ( //cert banners
|
const ( //cert banners
|
||||||
CertificateBanner = "NEBULA CERTIFICATE"
|
CertificateBanner = "NEBULA CERTIFICATE"
|
||||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||||
@@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
return nil, r, ErrInvalidPEMBlock
|
return nil, r, ErrInvalidPEMBlock
|
||||||
}
|
}
|
||||||
|
|
||||||
var c Certificate
|
c, err := unmarshalCertificateBlock(p)
|
||||||
var err error
|
|
||||||
|
|
||||||
switch p.Type {
|
|
||||||
// Implementations must validate the resulting certificate contains valid information
|
|
||||||
case CertificateBanner:
|
|
||||||
c, err = unmarshalCertificateV1(p.Bytes, nil)
|
|
||||||
case CertificateV2Banner:
|
|
||||||
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
|
|
||||||
default:
|
|
||||||
return nil, r, ErrInvalidPEMCertificateBanner
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, r, err
|
return nil, r, err
|
||||||
}
|
}
|
||||||
@@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
|
||||||
|
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
|
||||||
|
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
|
||||||
|
switch block.Type {
|
||||||
|
// Implementations must validate the resulting certificate contains valid information
|
||||||
|
case CertificateBanner:
|
||||||
|
return unmarshalCertificateV1(block.Bytes, nil)
|
||||||
|
case CertificateV2Banner:
|
||||||
|
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
|
||||||
|
default:
|
||||||
|
return nil, ErrInvalidPEMCertificateBanner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||||
if c.IsCA() {
|
if c.IsCA() {
|
||||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||||
|
|||||||
@@ -1,12 +1,88 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func scanAll(t *testing.T, input string) ([]string, error) {
|
||||||
|
t.Helper()
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||||
|
scanner.Split(SplitPEM)
|
||||||
|
var blocks []string
|
||||||
|
for scanner.Scan() {
|
||||||
|
blocks = append(blocks, scanner.Text())
|
||||||
|
}
|
||||||
|
return blocks, scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_Single(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 1)
|
||||||
|
require.Equal(t, input, blocks[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_Multiple(t *testing.T) {
|
||||||
|
block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n"
|
||||||
|
block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
|
||||||
|
blocks, err := scanAll(t, block1+block2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 2)
|
||||||
|
require.Equal(t, block1, blocks[0])
|
||||||
|
require.Equal(t, block2, blocks[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) {
|
||||||
|
input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_Empty(t *testing.T) {
|
||||||
|
blocks, err := scanAll(t, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_WhitespaceOnly(t *testing.T) {
|
||||||
|
blocks, err := scanAll(t, " \n\t\n ")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_TrailingGarbage(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||||
|
require.Len(t, blocks, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_TruncatedBlock(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\npartial data with no end"
|
||||||
|
_, err := scanAll(t, input)
|
||||||
|
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_NoEndNewline(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 1)
|
||||||
|
require.Equal(t, input, blocks[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_GarbageOnly(t *testing.T) {
|
||||||
|
_, err := scanAll(t, "this is not PEM data")
|
||||||
|
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
||||||
goodCert := []byte(`
|
goodCert := []byte(`
|
||||||
# A good cert
|
# A good cert
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TBSCertificate represents a certificate intended to be signed.
|
// TBSCertificate represents a certificate intended to be signed.
|
||||||
@@ -126,6 +128,13 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if curve == Curve_P256 {
|
||||||
|
sig, err = p256.Normalize(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = c.setSignature(sig)
|
err = c.setSignature(sig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -89,3 +90,48 @@ func TestCertificateV1_SignP256(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, uc)
|
assert.NotNil(t, uc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) {
|
||||||
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
|
pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
|
||||||
|
|
||||||
|
tbs := TBSCertificate{
|
||||||
|
Version: Version1,
|
||||||
|
Name: "testing",
|
||||||
|
Networks: []netip.Prefix{
|
||||||
|
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||||
|
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||||
|
},
|
||||||
|
UnsafeNetworks: []netip.Prefix{
|
||||||
|
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||||
|
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||||
|
},
|
||||||
|
Groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||||
|
NotBefore: before,
|
||||||
|
NotAfter: after,
|
||||||
|
PublicKey: pubKey,
|
||||||
|
IsCA: true,
|
||||||
|
Curve: Curve_P256,
|
||||||
|
}
|
||||||
|
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
|
||||||
|
rawPriv := priv.D.FillBytes(make([]byte, 32))
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
if i&1 == 1 {
|
||||||
|
tbs.Version = Version1
|
||||||
|
} else {
|
||||||
|
tbs.Version = Version2
|
||||||
|
}
|
||||||
|
c, err := tbs.Sign(nil, Curve_P256, rawPriv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.True(t, c.CheckSignature(pub))
|
||||||
|
normie, err := p256.IsNormalized(c.Signature())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, normie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -163,3 +163,55 @@ func P256Keypair() ([]byte, []byte) {
|
|||||||
pubkey := privkey.PublicKey()
|
pubkey := privkey.PublicKey()
|
||||||
return pubkey.Bytes(), privkey.Bytes()
|
return pubkey.Bytes(), privkey.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DummyCert is a minimal cert.Certificate implementation for testing error paths.
|
||||||
|
type DummyCert struct {
|
||||||
|
Version_ cert.Version
|
||||||
|
Curve_ cert.Curve
|
||||||
|
Groups_ []string
|
||||||
|
IsCA_ bool
|
||||||
|
Issuer_ string
|
||||||
|
Name_ string
|
||||||
|
Networks_ []netip.Prefix
|
||||||
|
NotAfter_ time.Time
|
||||||
|
NotBefore_ time.Time
|
||||||
|
PublicKey_ []byte
|
||||||
|
Signature_ []byte
|
||||||
|
UnsafeNetworks_ []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DummyCert) Version() cert.Version { return d.Version_ }
|
||||||
|
func (d *DummyCert) Curve() cert.Curve { return d.Curve_ }
|
||||||
|
func (d *DummyCert) Groups() []string { return d.Groups_ }
|
||||||
|
func (d *DummyCert) IsCA() bool { return d.IsCA_ }
|
||||||
|
func (d *DummyCert) Issuer() string { return d.Issuer_ }
|
||||||
|
func (d *DummyCert) Name() string { return d.Name_ }
|
||||||
|
func (d *DummyCert) Networks() []netip.Prefix { return d.Networks_ }
|
||||||
|
func (d *DummyCert) NotAfter() time.Time { return d.NotAfter_ }
|
||||||
|
func (d *DummyCert) NotBefore() time.Time { return d.NotBefore_ }
|
||||||
|
func (d *DummyCert) PublicKey() []byte { return d.PublicKey_ }
|
||||||
|
func (d *DummyCert) Signature() []byte { return d.Signature_ }
|
||||||
|
func (d *DummyCert) UnsafeNetworks() []netip.Prefix { return d.UnsafeNetworks_ }
|
||||||
|
func (d *DummyCert) Fingerprint() (string, error) { return "", nil }
|
||||||
|
func (d *DummyCert) CheckSignature(key []byte) bool { return false }
|
||||||
|
func (d *DummyCert) MarshalForHandshakes() ([]byte, error) { return nil, nil }
|
||||||
|
func (d *DummyCert) MarshalPEM() ([]byte, error) { return nil, nil }
|
||||||
|
func (d *DummyCert) MarshalJSON() ([]byte, error) { return nil, nil }
|
||||||
|
func (d *DummyCert) Marshal() ([]byte, error) { return nil, nil }
|
||||||
|
func (d *DummyCert) String() string { return "dummy" }
|
||||||
|
func (d *DummyCert) Copy() cert.Certificate { return d }
|
||||||
|
func (d *DummyCert) VerifyPrivateKey(c cert.Curve, k []byte) error { return nil }
|
||||||
|
func (d *DummyCert) Expired(time.Time) bool { return false }
|
||||||
|
func (d *DummyCert) MarshalPublicKeyPEM() []byte { return nil }
|
||||||
|
func (d *DummyCert) PublicKeyPEM() []byte { return nil }
|
||||||
|
|
||||||
|
// NewTestCAPool creates a CAPool from the given CA certificates, panicking on error.
|
||||||
|
func NewTestCAPool(cas ...cert.Certificate) *cert.CAPool {
|
||||||
|
pool := cert.NewCAPool()
|
||||||
|
for _, ca := range cas {
|
||||||
|
if err := pool.AddCA(ca); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
@@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rawCACert, err := os.ReadFile(*vf.caPath)
|
caFile, err := os.Open(*vf.caPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while reading ca: %w", err)
|
return fmt.Errorf("error while reading ca: %w", err)
|
||||||
}
|
}
|
||||||
|
defer caFile.Close()
|
||||||
|
|
||||||
caPool := cert.NewCAPool()
|
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
|
||||||
for {
|
if err != nil && !errors.Is(err, cert.ErrExpired) {
|
||||||
rawCACert, err = caPool.AddCAFromPEM(rawCACert)
|
return fmt.Errorf("error while adding ca cert to pool: %w", err)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while adding ca cert to pool: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rawCert, err := os.ReadFile(*vf.certPath)
|
rawCert, err := os.ReadFile(*vf.certPath)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func Test_verify(t *testing.T) {
|
|||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
require.ErrorIs(t, err, cert.ErrInvalidPEMBlock)
|
||||||
|
|
||||||
// make a ca for later
|
// make a ca for later
|
||||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
|||||||
@@ -3,8 +3,15 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
func HookLogger(l *logrus.Logger) {
|
"github.com/slackhq/nebula/logging"
|
||||||
// Do nothing, let the logs flow to stdout/stderr
|
)
|
||||||
|
|
||||||
|
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
|
||||||
|
// platforms have no special sink to integrate with.
|
||||||
|
func newPlatformLogger() *slog.Logger {
|
||||||
|
return logging.NewLogger(os.Stdout)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,54 +1,86 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"context"
|
||||||
"io/ioutil"
|
"log/slog"
|
||||||
"os"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
|
// newPlatformLogger returns a *slog.Logger that routes every log record
|
||||||
// logrus output will be discarded
|
// through the Windows service logger so records end up in the Windows
|
||||||
func HookLogger(l *logrus.Logger) {
|
// Event Log. All the heavy lifting (level management, format swap,
|
||||||
l.AddHook(newLogHook(logger))
|
// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
|
||||||
l.SetOutput(ioutil.Discard)
|
// this file only contributes:
|
||||||
|
//
|
||||||
|
// - an io.Writer that forwards each formatted line to the service
|
||||||
|
// logger at the current record's Event Log severity, and
|
||||||
|
// - a thin severityTag that embeds *logging.Handler and overrides
|
||||||
|
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
|
||||||
|
// column and severity-based filters keep working the way they did
|
||||||
|
// before the slog migration.
|
||||||
|
//
|
||||||
|
// Format (text vs json) is carried by the embedded *logging.Handler, so
|
||||||
|
// logging.format: json in config still produces JSON lines in Event
|
||||||
|
// Viewer, same as the pre-slog logrus setup.
|
||||||
|
func newPlatformLogger() *slog.Logger {
|
||||||
|
w := &eventLogWriter{}
|
||||||
|
return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w})
|
||||||
}
|
}
|
||||||
|
|
||||||
type logHook struct {
|
// eventLogWriter forwards slog-formatted lines to the Windows service
|
||||||
sl service.Logger
|
// logger at the severity most recently stashed by severityTag.Handle.
|
||||||
|
// The mutex serializes the stash + inner.Handle + Write cycle per record
|
||||||
|
// across all concurrent goroutines; slog's builtin text/json handlers
|
||||||
|
// each hold their own mutex around Write, but that only protects the
|
||||||
|
// Write call itself, not our stash-then-handle sequence.
|
||||||
|
type eventLogWriter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
level slog.Level
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLogHook(sl service.Logger) *logHook {
|
func (w *eventLogWriter) Write(p []byte) (int, error) {
|
||||||
return &logHook{sl: sl}
|
line := strings.TrimRight(string(p), "\n")
|
||||||
}
|
switch {
|
||||||
|
case w.level >= slog.LevelError:
|
||||||
func (h *logHook) Fire(entry *logrus.Entry) error {
|
return len(p), logger.Error(line)
|
||||||
line, err := entry.String()
|
case w.level >= slog.LevelWarn:
|
||||||
if err != nil {
|
return len(p), logger.Warning(line)
|
||||||
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch entry.Level {
|
|
||||||
case logrus.PanicLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.FatalLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.ErrorLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.WarnLevel:
|
|
||||||
return h.sl.Warning(line)
|
|
||||||
case logrus.InfoLevel:
|
|
||||||
return h.sl.Info(line)
|
|
||||||
case logrus.DebugLevel:
|
|
||||||
return h.sl.Info(line)
|
|
||||||
default:
|
default:
|
||||||
return nil
|
return len(p), logger.Info(line)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *logHook) Levels() []logrus.Level {
|
// severityTag embeds *logging.Handler to pick up everything it does for
|
||||||
return logrus.AllLevels
|
// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
|
||||||
|
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
|
||||||
|
// so each record's slog.Level is stashed on the writer before formatting
|
||||||
|
// and so derived handlers stay wrapped as severityTag rather than
|
||||||
|
// downgrading to bare *logging.Handler.
|
||||||
|
type severityTag struct {
|
||||||
|
*logging.Handler
|
||||||
|
w *eventLogWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
|
||||||
|
s.w.mu.Lock()
|
||||||
|
defer s.w.mu.Unlock()
|
||||||
|
s.w.level = r.Level
|
||||||
|
return s.Handler.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) WithGroup(name string) slog.Handler {
|
||||||
|
if name == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,9 +50,14 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
l := logging.NewLogger(os.Stdout)
|
||||||
|
|
||||||
if *serviceFlag != "" {
|
if *serviceFlag != "" {
|
||||||
doService(configPath, configTest, Build, serviceFlag)
|
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
|
||||||
os.Exit(1)
|
l.Error("Service command failed", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if *configPath == "" {
|
if *configPath == "" {
|
||||||
@@ -61,9 +66,6 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
|
||||||
l.Out = os.Stdout
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*configPath)
|
err := c.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -71,6 +73,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
fmt.Printf("failed to apply logging config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||||
@@ -78,8 +90,20 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
ctrl.ShutdownBlock()
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
|
|
||||||
|
if err := wait(); err != nil {
|
||||||
|
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger service.Logger
|
var logger service.Logger
|
||||||
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
|
|||||||
// Start should not block.
|
// Start should not block.
|
||||||
logger.Info("Nebula service starting.")
|
logger.Info("Nebula service starting.")
|
||||||
|
|
||||||
l := logrus.New()
|
l := newPlatformLogger()
|
||||||
HookLogger(l)
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*p.configPath)
|
err := c.Load(*p.configPath)
|
||||||
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
|
|||||||
return fmt.Errorf("failed to load config: %s", err)
|
return fmt.Errorf("failed to load config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
return fmt.Errorf("failed to apply logging config: %s", err)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -57,11 +65,11 @@ func fileExists(filename string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
|
||||||
if *configPath == "" {
|
if *configPath == "" {
|
||||||
ex, err := os.Executable()
|
ex, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return err
|
||||||
}
|
}
|
||||||
*configPath = filepath.Dir(ex) + "/config.yaml"
|
*configPath = filepath.Dir(ex) + "/config.yaml"
|
||||||
if !fileExists(*configPath) {
|
if !fileExists(*configPath) {
|
||||||
@@ -85,16 +93,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
|
|||||||
// Here are what the different loggers are doing:
|
// Here are what the different loggers are doing:
|
||||||
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
|
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
|
||||||
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
|
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
|
||||||
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
|
// - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
|
||||||
s, err := service.New(prg, svcConfig)
|
s, err := service.New(prg, svcConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
errs := make(chan error, 5)
|
errs := make(chan error, 5)
|
||||||
logger, err = s.Logger(errs)
|
logger, err = s.Logger(errs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -109,18 +117,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
|
|||||||
|
|
||||||
switch *serviceFlag {
|
switch *serviceFlag {
|
||||||
case "run":
|
case "run":
|
||||||
err = s.Run()
|
if err := s.Run(); err != nil {
|
||||||
if err != nil {
|
|
||||||
// Route any errors to the system logger
|
// Route any errors to the system logger
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
err := service.Control(s, *serviceFlag)
|
if err := service.Control(s, *serviceFlag); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Printf("Valid actions: %q\n", service.ControlAction)
|
log.Printf("Valid actions: %q\n", service.ControlAction)
|
||||||
log.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,8 +55,7 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
l := logging.NewLogger(os.Stdout)
|
||||||
l.Out = os.Stdout
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*configPath)
|
err := c.Load(*configPath)
|
||||||
@@ -65,6 +64,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
fmt.Printf("failed to apply logging config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||||
@@ -72,9 +81,21 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
ctrl.ShutdownBlock()
|
|
||||||
|
if err := wait(); err != nil {
|
||||||
|
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
|
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
|
||||||
@@ -13,30 +12,30 @@ import (
|
|||||||
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
|
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
|
||||||
const SdNotifyReady = "READY=1"
|
const SdNotifyReady = "READY=1"
|
||||||
|
|
||||||
func notifyReady(l *logrus.Logger) {
|
func notifyReady(l *slog.Logger) {
|
||||||
sockName := os.Getenv("NOTIFY_SOCKET")
|
sockName := os.Getenv("NOTIFY_SOCKET")
|
||||||
if sockName == "" {
|
if sockName == "" {
|
||||||
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
|
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("failed to connect to systemd notification socket")
|
l.Error("failed to connect to systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
|
l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
|
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
|
||||||
l.WithError(err).Error("failed to signal the systemd notification socket")
|
l.Error("failed to signal the systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Debugln("notified systemd the service is ready")
|
l.Debug("notified systemd the service is ready")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import "log/slog"
|
||||||
|
|
||||||
func notifyReady(_ *logrus.Logger) {
|
func notifyReady(_ *slog.Logger) {
|
||||||
// No init service to notify
|
// No init service to notify
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -16,7 +17,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,11 +26,11 @@ type C struct {
|
|||||||
Settings map[string]any
|
Settings map[string]any
|
||||||
oldSettings map[string]any
|
oldSettings map[string]any
|
||||||
callbacks []func(*C)
|
callbacks []func(*C)
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
reloadLock sync.Mutex
|
reloadLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewC(l *logrus.Logger) *C {
|
func NewC(l *slog.Logger) *C {
|
||||||
return &C{
|
return &C{
|
||||||
Settings: make(map[string]any),
|
Settings: make(map[string]any),
|
||||||
l: l,
|
l: l,
|
||||||
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
|
|||||||
|
|
||||||
newVals, err := yaml.Marshal(nv)
|
newVals, err := yaml.Marshal(nv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
c.l.Error("Error while marshaling new config",
|
||||||
|
"config_path", k,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldVals, err := yaml.Marshal(ov)
|
oldVals, err := yaml.Marshal(ov)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
c.l.Error("Error while marshaling old config",
|
||||||
|
"config_path", k,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(newVals) != string(oldVals)
|
return string(newVals) != string(oldVals)
|
||||||
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
|
|||||||
|
|
||||||
err := c.Load(c.path)
|
err := c.Load(c.path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
c.l.Error("Error occurred while reloading config",
|
||||||
|
"config_path", c.path,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -47,10 +47,10 @@ type connectionManager struct {
|
|||||||
|
|
||||||
metricsTxPunchy metrics.Counter
|
metricsTxPunchy metrics.Counter
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||||
cm := &connectionManager{
|
cm := &connectionManager{
|
||||||
hostMap: hm,
|
hostMap: hm,
|
||||||
l: l,
|
l: l,
|
||||||
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
|||||||
old := cm.getInactivityTimeout()
|
old := cm.getInactivityTimeout()
|
||||||
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
||||||
if !initial {
|
if !initial {
|
||||||
cm.l.WithField("oldDuration", old).
|
cm.l.Info("Inactivity timeout has changed",
|
||||||
WithField("newDuration", cm.getInactivityTimeout()).
|
"oldDuration", old,
|
||||||
Info("Inactivity timeout has changed")
|
"newDuration", cm.getInactivityTimeout(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
|||||||
old := cm.dropInactive.Load()
|
old := cm.dropInactive.Load()
|
||||||
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
||||||
if !initial {
|
if !initial {
|
||||||
cm.l.WithField("oldBool", old).
|
cm.l.Info("Drop inactive setting has changed",
|
||||||
WithField("newBool", cm.dropInactive.Load()).
|
"oldBool", old,
|
||||||
Info("Drop inactive setting has changed")
|
"newBool", cm.dropInactive.Load(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
var err error
|
var err error
|
||||||
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
|
||||||
} else {
|
} else {
|
||||||
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
cm.l.WithFields(logrus.Fields{
|
cm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": req.RelayFromAddr,
|
"relayFrom", req.RelayFromAddr,
|
||||||
"relayTo": req.RelayToAddr,
|
"relayTo", req.RelayToAddr,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex", req.ResponderRelayIndex,
|
||||||
"vpnAddrs": newhostinfo.vpnAddrs}).
|
"vpnAddrs", newhostinfo.vpnAddrs,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
hostinfo := cm.hostMap.Indexes[localIndex]
|
hostinfo := cm.hostMap.Indexes[localIndex]
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
|
cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
// A hostinfo is determined alive if there is incoming traffic
|
// A hostinfo is determined alive if there is incoming traffic
|
||||||
if inTraffic {
|
if inTraffic {
|
||||||
decision := doNothing
|
decision := doNothing
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
"tunnelCheck", m{"state": "alive", "method": "passive"},
|
||||||
Debug("Tunnel status")
|
)
|
||||||
}
|
}
|
||||||
hostinfo.pendingDeletion.Store(false)
|
hostinfo.pendingDeletion.Store(false)
|
||||||
|
|
||||||
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if hostinfo.pendingDeletion.Load() {
|
if hostinfo.pendingDeletion.Load() {
|
||||||
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Info("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
"tunnelCheck", m{"state": "dead", "method": "active"},
|
||||||
Info("Tunnel status")
|
)
|
||||||
|
|
||||||
return deleteTunnel, hostinfo, nil
|
return deleteTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
||||||
if isInactive {
|
if isInactive {
|
||||||
// Tunnel is inactive, tear it down
|
// Tunnel is inactive, tear it down
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
|
||||||
WithField("inactiveDuration", inactiveFor).
|
"inactiveDuration", inactiveFor,
|
||||||
WithField("primary", mainHostInfo).
|
"primary", mainHostInfo,
|
||||||
Info("Dropping tunnel due to inactivity")
|
)
|
||||||
|
|
||||||
return closeTunnel, hostinfo, primary
|
return closeTunnel, hostinfo, primary
|
||||||
}
|
}
|
||||||
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
cm.sendPunch(hostinfo)
|
cm.sendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
"tunnelCheck", m{"state": "testing", "method": "active"},
|
||||||
Debug("Tunnel status")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
decision = sendTestPacket
|
decision = sendTestPacket
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
|
hostinfo.logger(cm.l).Debug("Hostinfo sadness")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
|
|||||||
return false //cert is still valid! yay!
|
return false //cert is still valid! yay!
|
||||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
"error", err,
|
||||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
"fingerprint", remoteCert.Fingerprint,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
} else if cm.intf.disconnectInvalid.Load() {
|
} else if cm.intf.disconnectInvalid.Load() {
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
"error", err,
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
"fingerprint", remoteCert.Fingerprint,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
} else {
|
} else {
|
||||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||||
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
curCrtVersion := curCrt.Version()
|
curCrtVersion := curCrt.Version()
|
||||||
myCrt := cs.getCertificate(curCrtVersion)
|
myCrt := cs.getCertificate(curCrtVersion)
|
||||||
if myCrt == nil {
|
if myCrt == nil {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("version", curCrtVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("reason", "local certificate removed").
|
"version", curCrtVersion,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate removed",
|
||||||
|
)
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("version", curCrtVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
"version", curCrtVersion,
|
||||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
"peerVersion", peerCrt.Certificate.Version(),
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate version lower than peer, attempting to correct",
|
||||||
|
)
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||||
})
|
})
|
||||||
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("reason", "local certificate is not current").
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate is not current",
|
||||||
|
)
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if curCrtVersion < cs.initiatingVersion {
|
if curCrtVersion < cs.initiatingVersion {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "current cert version < pki.initiatingVersion",
|
||||||
|
)
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/overlaytest"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -46,13 +46,13 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
initiatingVersion: cert.Version1,
|
initiatingVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1Credential: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -63,9 +63,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
@@ -79,7 +79,6 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
myCert: &dummyCert{version: cert.Version1},
|
myCert: &dummyCert{version: cert.Version1},
|
||||||
H: &noise.HandshakeState{},
|
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
@@ -129,13 +128,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
initiatingVersion: cert.Version1,
|
initiatingVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1Credential: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -146,9 +145,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
@@ -162,7 +161,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
myCert: &dummyCert{version: cert.Version1},
|
myCert: &dummyCert{version: cert.Version1},
|
||||||
H: &noise.HandshakeState{},
|
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
@@ -214,13 +212,13 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
initiatingVersion: cert.Version1,
|
initiatingVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1Credential: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -231,12 +229,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
conf.Settings["tunnels"] = map[string]any{
|
conf.Settings["tunnels"] = map[string]any{
|
||||||
"drop_inactive": true,
|
"drop_inactive": true,
|
||||||
}
|
}
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
assert.True(t, nc.dropInactive.Load())
|
assert.True(t, nc.dropInactive.Load())
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
|
|
||||||
@@ -248,7 +246,6 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
myCert: &dummyCert{version: cert.Version1},
|
myCert: &dummyCert{version: cert.Version1},
|
||||||
H: &noise.HandshakeState{},
|
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
@@ -339,15 +336,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{},
|
v1Cert: &dummyCert{},
|
||||||
v1HandshakeBytes: []byte{},
|
v1Credential: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -360,9 +357,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
ifce.disconnectInvalid.Store(true)
|
ifce.disconnectInvalid.Store(true)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
|
|
||||||
@@ -371,7 +368,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
myCert: &dummyCert{},
|
myCert: &dummyCert{},
|
||||||
peerCert: cachedPeerCert,
|
peerCert: cachedPeerCert,
|
||||||
H: &noise.HandshakeState{},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/handshake"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ReplayWindow = 1024
|
const ReplayWindow = 1024
|
||||||
@@ -18,7 +14,6 @@ const ReplayWindow = 1024
|
|||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
eKey *NebulaCipherState
|
eKey *NebulaCipherState
|
||||||
dKey *NebulaCipherState
|
dKey *NebulaCipherState
|
||||||
H *noise.HandshakeState
|
|
||||||
myCert cert.Certificate
|
myCert cert.Certificate
|
||||||
peerCert *cert.CachedCertificate
|
peerCert *cert.CachedCertificate
|
||||||
initiator bool
|
initiator bool
|
||||||
@@ -27,55 +22,24 @@ type ConnectionState struct {
|
|||||||
writeLock sync.Mutex
|
writeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
// newConnectionStateFromResult builds a fully-populated ConnectionState from a
|
||||||
var dhFunc noise.DHFunc
|
// completed handshake.Result. It seeds messageCounter and the replay window so
|
||||||
switch crt.Curve() {
|
// that the post-handshake message indices already used on the wire don't count
|
||||||
case cert.Curve_CURVE25519:
|
// as missed traffic in the data plane.
|
||||||
dhFunc = noise.DH25519
|
func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
|
||||||
case cert.Curve_P256:
|
|
||||||
if cs.pkcs11Backed {
|
|
||||||
dhFunc = noiseutil.DHP256PKCS11
|
|
||||||
} else {
|
|
||||||
dhFunc = noiseutil.DHP256
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
|
|
||||||
}
|
|
||||||
|
|
||||||
var ncs noise.CipherSuite
|
|
||||||
if cs.cipher == "chachapoly" {
|
|
||||||
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
|
||||||
} else {
|
|
||||||
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
|
||||||
}
|
|
||||||
|
|
||||||
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
|
||||||
hs, err := noise.NewHandshakeState(noise.Config{
|
|
||||||
CipherSuite: ncs,
|
|
||||||
Random: rand.Reader,
|
|
||||||
Pattern: pattern,
|
|
||||||
Initiator: initiator,
|
|
||||||
StaticKeypair: static,
|
|
||||||
//NOTE: These should come from CertState (pki.go) when we finally implement it
|
|
||||||
PresharedKey: []byte{},
|
|
||||||
PresharedKeyPlacement: 0,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("NewConnectionState: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The queue and ready params prevent a counter race that would happen when
|
|
||||||
// sending stored packets and simultaneously accepting new traffic.
|
|
||||||
ci := &ConnectionState{
|
ci := &ConnectionState{
|
||||||
H: hs,
|
myCert: r.MyCert,
|
||||||
initiator: initiator,
|
initiator: r.Initiator,
|
||||||
|
peerCert: r.RemoteCert,
|
||||||
|
eKey: NewNebulaCipherState(r.EKey),
|
||||||
|
dKey: NewNebulaCipherState(r.DKey),
|
||||||
window: NewBits(ReplayWindow),
|
window: NewBits(ReplayWindow),
|
||||||
myCert: crt,
|
|
||||||
}
|
}
|
||||||
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
ci.messageCounter.Add(r.MessageIndex)
|
||||||
ci.messageCounter.Add(2)
|
for i := uint64(1); i <= r.MessageIndex; i++ {
|
||||||
|
ci.window.Update(nil, i)
|
||||||
return ci, nil
|
}
|
||||||
|
return ci
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||||
|
|||||||
114
connection_state_test.go
Normal file
114
connection_state_test.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/flynn/noise"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
ct "github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/handshake"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// runTestHandshake runs a complete IX handshake between two freshly-built
|
||||||
|
// peers and returns the initiator and responder Results. Used to produce
|
||||||
|
// real cipher states for tests that need to exercise post-handshake glue.
|
||||||
|
func runTestHandshake(t *testing.T) (initR, respR *handshake.Result) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
|
||||||
|
makeCreds := func(name string, networks []netip.Prefix) handshake.GetCredentialFunc {
|
||||||
|
c, _, rawKey, _ := ct.NewTestCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||||
|
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
|
||||||
|
)
|
||||||
|
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
hsBytes, err := c.MarshalForHandshakes()
|
||||||
|
require.NoError(t, err)
|
||||||
|
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
|
cred := handshake.NewCredential(c, hsBytes, priv, ncs)
|
||||||
|
return func(v cert.Version) *handshake.Credential {
|
||||||
|
if v == cert.Version2 {
|
||||||
|
return cred
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := func(c cert.Certificate) (*cert.CachedCertificate, error) {
|
||||||
|
return caPool.VerifyCertificate(time.Now(), c)
|
||||||
|
}
|
||||||
|
|
||||||
|
initCreds := makeCreds("initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
respCreds := makeCreds("responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||||
|
|
||||||
|
initM, err := handshake.NewMachine(
|
||||||
|
cert.Version2, initCreds, verifier,
|
||||||
|
func() (uint32, error) { return 1000, nil },
|
||||||
|
true, header.HandshakeIXPSK0,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
respM, err := handshake.NewMachine(
|
||||||
|
cert.Version2, respCreds, verifier,
|
||||||
|
func() (uint32, error) { return 2000, nil },
|
||||||
|
false, header.HandshakeIXPSK0,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, respR, err := respM.ProcessPacket(nil, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, respR)
|
||||||
|
|
||||||
|
_, initR, err = initM.ProcessPacket(nil, resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, initR)
|
||||||
|
|
||||||
|
return initR, respR
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionStateFromResult(t *testing.T) {
|
||||||
|
initR, respR := runTestHandshake(t)
|
||||||
|
|
||||||
|
t.Run("initiator", func(t *testing.T) {
|
||||||
|
ci := newConnectionStateFromResult(initR)
|
||||||
|
assert.True(t, ci.initiator)
|
||||||
|
assert.Equal(t, initR.MyCert, ci.myCert)
|
||||||
|
assert.Equal(t, initR.RemoteCert, ci.peerCert)
|
||||||
|
assert.NotNil(t, ci.eKey)
|
||||||
|
assert.NotNil(t, ci.dKey)
|
||||||
|
|
||||||
|
// IX has 2 handshake messages; the next data-plane send is counter=3.
|
||||||
|
assert.Equal(t, uint64(2), ci.messageCounter.Load(),
|
||||||
|
"messageCounter must equal Result.MessageIndex so the next send is N+1")
|
||||||
|
|
||||||
|
// Both handshake counters must be marked seen so they don't appear lost.
|
||||||
|
// Check returns false if an index has already been recorded.
|
||||||
|
assert.False(t, ci.window.Check(nil, 1), "counter 1 must already be seen")
|
||||||
|
assert.False(t, ci.window.Check(nil, 2), "counter 2 must already be seen")
|
||||||
|
// Counter 3 is the next data-plane message and must NOT be pre-marked.
|
||||||
|
assert.True(t, ci.window.Check(nil, 3), "counter 3 must not be pre-seeded")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("responder", func(t *testing.T) {
|
||||||
|
ci := newConnectionStateFromResult(respR)
|
||||||
|
assert.False(t, ci.initiator)
|
||||||
|
assert.Equal(t, respR.MyCert, ci.myCert)
|
||||||
|
assert.Equal(t, respR.RemoteCert, ci.peerCert)
|
||||||
|
assert.NotNil(t, ci.eKey)
|
||||||
|
assert.NotNil(t, ci.dKey)
|
||||||
|
assert.Equal(t, uint64(2), ci.messageCounter.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
92
control.go
92
control.go
@@ -2,17 +2,33 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RunState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateUnknown RunState = iota
|
||||||
|
StateReady
|
||||||
|
StateStarted
|
||||||
|
StateStopping
|
||||||
|
StateStopped
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrAlreadyStarted = errors.New("nebula is already started")
|
||||||
|
var ErrAlreadyStopped = errors.New("nebula cannot be restarted")
|
||||||
|
var ErrUnknownState = errors.New("nebula state is invalid")
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -26,8 +42,11 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
|
stateLock sync.Mutex
|
||||||
|
state RunState
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
sshStart func()
|
sshStart func()
|
||||||
@@ -49,10 +68,31 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
// Start actually runs nebula, this is a nonblocking call.
|
||||||
func (c *Control) Start() {
|
// The returned function blocks until nebula has fully stopped and returns the
|
||||||
|
// first fatal reader error (if any). A nil error means nebula shut down
|
||||||
|
// gracefully; a non-nil error means a reader hit an unexpected failure that
|
||||||
|
// triggered the shutdown.
|
||||||
|
func (c *Control) Start() (func() error, error) {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
switch c.state {
|
||||||
|
case StateReady:
|
||||||
|
//yay!
|
||||||
|
case StateStopped, StateStopping:
|
||||||
|
return nil, ErrAlreadyStopped
|
||||||
|
case StateStarted:
|
||||||
|
return nil, ErrAlreadyStarted
|
||||||
|
default:
|
||||||
|
return nil, ErrUnknownState
|
||||||
|
}
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
c.f.activate()
|
err := c.f.activate()
|
||||||
|
if err != nil {
|
||||||
|
c.state = StateStopped
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -71,25 +111,51 @@ func (c *Control) Start() {
|
|||||||
c.lighthouseStart()
|
c.lighthouseStart()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.f.triggerShutdown = c.Stop
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.f.run()
|
out, err := c.f.run()
|
||||||
|
if err != nil {
|
||||||
|
c.state = StateStopped
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.state = StateStarted
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Control) State() RunState {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
return c.state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
if c.state != StateStarted {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
// We are stopping or stopped already
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = StateStopping
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
|
|
||||||
c.CloseAllTunnels(false)
|
c.CloseAllTunnels(false)
|
||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.Error("Close interface failed", "error", err)
|
||||||
}
|
}
|
||||||
c.l.Info("Goodbye")
|
c.stateLock.Lock()
|
||||||
|
c.state = StateStopped
|
||||||
|
c.stateLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
@@ -100,7 +166,7 @@ func (c *Control) ShutdownBlock() {
|
|||||||
|
|
||||||
rawSig := <-sigChan
|
rawSig := <-sigChan
|
||||||
sig := rawSig.String()
|
sig := rawSig.String()
|
||||||
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
|
c.l.Info("Caught signal, shutting down", "signal", sig)
|
||||||
c.Stop()
|
c.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -237,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
|||||||
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
c.f.closeTunnel(h)
|
c.f.closeTunnel(h)
|
||||||
|
|
||||||
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
|
c.l.Debug("Sending close tunnel message",
|
||||||
Debug("Sending close tunnel message")
|
"vpnAddrs", h.vpnAddrs,
|
||||||
|
"udpAddr", h.remote,
|
||||||
|
)
|
||||||
closed++
|
closed++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -79,10 +78,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
}, &Interface{})
|
}, &Interface{})
|
||||||
|
|
||||||
c := Control{
|
c := Control{
|
||||||
|
state: StateReady,
|
||||||
f: &Interface{
|
f: &Interface{
|
||||||
hostMap: hm,
|
hostMap: hm,
|
||||||
},
|
},
|
||||||
l: logrus.New(),
|
l: test.NewLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
//go:build e2e_testing
|
//go:build e2e_testing
|
||||||
// +build e2e_testing
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
@@ -23,7 +20,9 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
pipeTo.InjectUDPPacket(p)
|
pipeTo.InjectUDPPacket(p)
|
||||||
if h.Type == msgType && h.Subtype == subType {
|
match := h.Type == msgType && h.Subtype == subType
|
||||||
|
p.Release()
|
||||||
|
if match {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -39,7 +38,9 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
pipeTo.InjectUDPPacket(p)
|
pipeTo.InjectUDPPacket(p)
|
||||||
if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType {
|
match := h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType
|
||||||
|
p.Release()
|
||||||
|
if match {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -91,65 +92,15 @@ func (c *Control) GetTunTxChan() <-chan []byte {
|
|||||||
return c.f.inside.(*overlay.TestTun).TxPackets
|
return c.f.inside.(*overlay.TestTun).TxPackets
|
||||||
}
|
}
|
||||||
|
|
||||||
// InjectUDPPacket will inject a packet into the udp side of nebula
|
// InjectUDPPacket injects a packet into the udp side. We copy internally so the caller keeps ownership of p.
|
||||||
|
// The copy comes from the freelist so steady-state alloc is zero.
|
||||||
func (c *Control) InjectUDPPacket(p *udp.Packet) {
|
func (c *Control) InjectUDPPacket(p *udp.Packet) {
|
||||||
c.f.outside.(*udp.TesterConn).Send(p)
|
c.f.outside.(*udp.TesterConn).Send(p.Copy())
|
||||||
}
|
}
|
||||||
|
|
||||||
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
|
// InjectTunPacket pushes an IP packet onto the tun interface.
|
||||||
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
|
func (c *Control) InjectTunPacket(packet []byte) {
|
||||||
serialize := make([]gopacket.SerializableLayer, 0)
|
c.f.inside.(*overlay.TestTun).Send(packet)
|
||||||
var netLayer gopacket.NetworkLayer
|
|
||||||
if toAddr.Is6() {
|
|
||||||
if !fromAddr.Is6() {
|
|
||||||
panic("Cant send ipv6 to ipv4")
|
|
||||||
}
|
|
||||||
ip := &layers.IPv6{
|
|
||||||
Version: 6,
|
|
||||||
NextHeader: layers.IPProtocolUDP,
|
|
||||||
SrcIP: fromAddr.Unmap().AsSlice(),
|
|
||||||
DstIP: toAddr.Unmap().AsSlice(),
|
|
||||||
}
|
|
||||||
serialize = append(serialize, ip)
|
|
||||||
netLayer = ip
|
|
||||||
} else {
|
|
||||||
if !fromAddr.Is4() {
|
|
||||||
panic("Cant send ipv4 to ipv6")
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := &layers.IPv4{
|
|
||||||
Version: 4,
|
|
||||||
TTL: 64,
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
SrcIP: fromAddr.Unmap().AsSlice(),
|
|
||||||
DstIP: toAddr.Unmap().AsSlice(),
|
|
||||||
}
|
|
||||||
serialize = append(serialize, ip)
|
|
||||||
netLayer = ip
|
|
||||||
}
|
|
||||||
|
|
||||||
udp := layers.UDP{
|
|
||||||
SrcPort: layers.UDPPort(fromPort),
|
|
||||||
DstPort: layers.UDPPort(toPort),
|
|
||||||
}
|
|
||||||
err := udp.SetNetworkLayerForChecksum(netLayer)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer := gopacket.NewSerializeBuffer()
|
|
||||||
opt := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
serialize = append(serialize, &udp, gopacket.Payload(data))
|
|
||||||
err = gopacket.SerializeLayers(buffer, opt, serialize...)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetVpnAddrs() []netip.Addr {
|
func (c *Control) GetVpnAddrs() []netip.Addr {
|
||||||
|
|||||||
22
dist/wireshark/nebula.lua
vendored
22
dist/wireshark/nebula.lua
vendored
@@ -84,30 +84,24 @@ end
|
|||||||
|
|
||||||
function nebula.prefs_changed()
|
function nebula.prefs_changed()
|
||||||
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
|
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
|
||||||
-- Nothing changed, bail
|
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Remove our old dissector
|
-- Remove all existing registrations
|
||||||
DissectorTable.get("udp.port"):remove_all(nebula)
|
DissectorTable.get("udp.port"):remove_all(nebula)
|
||||||
|
|
||||||
if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then
|
if nebula.prefs.all_ports then
|
||||||
default_settings.all_port = nebula.prefs.all_ports
|
-- Register on every port for hole punch capture
|
||||||
|
|
||||||
for i=0, 65535 do
|
for i=0, 65535 do
|
||||||
DissectorTable.get("udp.port"):add(i, nebula)
|
DissectorTable.get("udp.port"):add(i, nebula)
|
||||||
end
|
end
|
||||||
|
else
|
||||||
-- no need to establish again on specific ports
|
-- Register on the configured port only
|
||||||
return
|
DissectorTable.get("udp.port"):add(nebula.prefs.port, nebula)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
default_settings.all_ports = nebula.prefs.all_ports
|
||||||
if default_settings.all_ports ~= nebula.prefs.all_ports then
|
default_settings.port = nebula.prefs.port
|
||||||
-- Add our new port dissector
|
|
||||||
default_settings.port = nebula.prefs.port
|
|
||||||
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
||||||
|
|||||||
299
dns_server.go
299
dns_server.go
@@ -1,63 +1,249 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This whole thing should be rewritten to use context
|
type dnsServer struct {
|
||||||
|
|
||||||
var dnsR *dnsRecords
|
|
||||||
var dnsServer *dns.Server
|
|
||||||
var dnsAddr string
|
|
||||||
|
|
||||||
type dnsRecords struct {
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
ctx context.Context
|
||||||
dnsMap4 map[string]netip.Addr
|
dnsMap4 map[string]netip.Addr
|
||||||
dnsMap6 map[string]netip.Addr
|
dnsMap6 map[string]netip.Addr
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
myVpnAddrsTable *bart.Lite
|
myVpnAddrsTable *bart.Lite
|
||||||
|
|
||||||
|
mux *dns.ServeMux
|
||||||
|
|
||||||
|
// enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`.
|
||||||
|
// Start, Add, and reload consult it so callers don't need to know the
|
||||||
|
// gating rules. When it toggles off via reload, accumulated records are
|
||||||
|
// cleared so a later re-enable starts with a fresh map populated from
|
||||||
|
// new handshakes.
|
||||||
|
enabled atomic.Bool
|
||||||
|
|
||||||
|
serverMu sync.Mutex
|
||||||
|
server *dns.Server
|
||||||
|
// started is closed once `server` has finished binding (or after
|
||||||
|
// ListenAndServe returns on a bind failure). Stop waits on it before
|
||||||
|
// calling Shutdown to avoid the miekg/dns "server not started" race
|
||||||
|
// where a Shutdown that arrives before bind completes is silently
|
||||||
|
// ignored, leaving the listener running forever.
|
||||||
|
started chan struct{}
|
||||||
|
addr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
// newDnsServerFromConfig builds a dnsServer, applies the initial config, and
|
||||||
return &dnsRecords{
|
// registers a reload callback. The reload callback is registered before the
|
||||||
|
// initial config is applied, so a SIGHUP can later enable, fix, or disable
|
||||||
|
// DNS even if the initial application failed.
|
||||||
|
//
|
||||||
|
// The dnsServer internally gates on `lighthouse.serve_dns &&
|
||||||
|
// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally,
|
||||||
|
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
|
||||||
|
// watcher that tears the listener down on nebula shutdown. The returned
|
||||||
|
// pointer is always non-nil, even on error.
|
||||||
|
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||||
|
ds := &dnsServer{
|
||||||
l: l,
|
l: l,
|
||||||
|
ctx: ctx,
|
||||||
dnsMap4: make(map[string]netip.Addr),
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
dnsMap6: make(map[string]netip.Addr),
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||||
}
|
}
|
||||||
|
ds.mux = dns.NewServeMux()
|
||||||
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||||
|
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := ds.reload(c, false); err != nil {
|
||||||
|
ds.l.Error("Failed to reload DNS responder from config", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ds.reload(c, true); err != nil {
|
||||||
|
return ds, err
|
||||||
|
}
|
||||||
|
return ds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
// reload applies the latest config and reconciles the running state with it:
|
||||||
|
// - enabled toggled on -> spawn a runner
|
||||||
|
// - enabled toggled off -> stop the runner
|
||||||
|
// - listen address changed (while running) -> restart on the new address
|
||||||
|
// - everything else -> no-op
|
||||||
|
//
|
||||||
|
// On the initial call it only records configuration; Control.Start is what
|
||||||
|
// launches the first runner via dnsStart.
|
||||||
|
func (d *dnsServer) reload(c *config.C, initial bool) error {
|
||||||
|
wantsDns := c.GetBool("lighthouse.serve_dns", false)
|
||||||
|
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||||
|
enabled := wantsDns && amLighthouse
|
||||||
|
newAddr := getDnsServerAddr(c)
|
||||||
|
|
||||||
|
d.serverMu.Lock()
|
||||||
|
running := d.server
|
||||||
|
runningStarted := d.started
|
||||||
|
sameAddr := d.addr == newAddr
|
||||||
|
d.addr = newAddr
|
||||||
|
d.enabled.Store(enabled)
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
|
||||||
|
if initial {
|
||||||
|
if wantsDns && !amLighthouse {
|
||||||
|
d.l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !enabled {
|
||||||
|
if running != nil {
|
||||||
|
d.Stop()
|
||||||
|
}
|
||||||
|
// Drop any records that accumulated while enabled; a later re-enable
|
||||||
|
// will repopulate from fresh handshakes.
|
||||||
|
d.clearRecords()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if running == nil {
|
||||||
|
// Was disabled (or never started); bring it up now.
|
||||||
|
go d.Start()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if sameAddr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d.shutdownServer(running, runningStarted, "reload")
|
||||||
|
// Old Start goroutine has now exited; bring up a fresh listener on the
|
||||||
|
// new address.
|
||||||
|
go d.Start()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdownServer waits for the server to finish binding (so Shutdown actually
|
||||||
|
// stops it rather than no-oping) and then shuts it down.
|
||||||
|
func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) {
|
||||||
|
if srv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if started != nil {
|
||||||
|
<-started
|
||||||
|
}
|
||||||
|
if err := srv.Shutdown(); err != nil {
|
||||||
|
d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start binds and serves the DNS responder. Blocks until Stop is called or
|
||||||
|
// the listener errors. Safe to call when DNS is disabled (returns
|
||||||
|
// immediately). This is what Control.dnsStart points at.
|
||||||
|
//
|
||||||
|
// Must be invoked after the tun device is active so that lighthouse.dns.host
|
||||||
|
// may bind to a nebula IP.
|
||||||
|
func (d *dnsServer) Start() {
|
||||||
|
if !d.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
started := make(chan struct{})
|
||||||
|
d.serverMu.Lock()
|
||||||
|
if d.ctx.Err() != nil {
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr := d.addr
|
||||||
|
server := &dns.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Net: "udp",
|
||||||
|
Handler: d.mux,
|
||||||
|
NotifyStartedFunc: func() { close(started) },
|
||||||
|
}
|
||||||
|
d.server = server
|
||||||
|
d.started = started
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
|
||||||
|
// Per-invocation ctx watcher. Exits when Start does, so we don't leak a
|
||||||
|
// watcher per reload-driven restart.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
d.shutdownServer(server, started, "shutdown")
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
d.l.Info("Starting DNS responder", "dnsListener", addr)
|
||||||
|
err := server.ListenAndServe()
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
// If the listener never bound (bind error) NotifyStartedFunc never fires,
|
||||||
|
// so close started here to release any Stop caller waiting on it.
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
default:
|
||||||
|
close(started)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
d.l.Warn("Failed to run the DNS responder", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop shuts down the active server, if any. Idempotent.
|
||||||
|
func (d *dnsServer) Stop() {
|
||||||
|
d.serverMu.Lock()
|
||||||
|
srv := d.server
|
||||||
|
started := d.started
|
||||||
|
d.server = nil
|
||||||
|
d.started = nil
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
d.shutdownServer(srv, started, "stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query returns the address for the given name and query type. The second
|
||||||
|
// return value reports whether the name is known at all (in either A or AAAA),
|
||||||
|
// which lets callers distinguish NODATA from NXDOMAIN.
|
||||||
|
func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) {
|
||||||
data = strings.ToLower(data)
|
data = strings.ToLower(data)
|
||||||
d.RLock()
|
d.RLock()
|
||||||
defer d.RUnlock()
|
defer d.RUnlock()
|
||||||
|
addr4, haveV4 := d.dnsMap4[data]
|
||||||
|
addr6, haveV6 := d.dnsMap6[data]
|
||||||
|
nameExists := haveV4 || haveV6
|
||||||
switch q {
|
switch q {
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
if r, ok := d.dnsMap4[data]; ok {
|
if haveV4 {
|
||||||
return r
|
return addr4, nameExists
|
||||||
}
|
}
|
||||||
case dns.TypeAAAA:
|
case dns.TypeAAAA:
|
||||||
if r, ok := d.dnsMap6[data]; ok {
|
if haveV6 {
|
||||||
return r
|
return addr6, nameExists
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return netip.Addr{}
|
return netip.Addr{}, nameExists
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) QueryCert(data string) string {
|
func (d *dnsServer) QueryCert(data string) string {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -80,8 +266,19 @@ func (d *dnsRecords) QueryCert(data string) string {
|
|||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clearRecords drops all DNS records.
|
||||||
|
func (d *dnsServer) clearRecords() {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
clear(d.dnsMap4)
|
||||||
|
clear(d.dnsMap6)
|
||||||
|
}
|
||||||
|
|
||||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||||
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
func (d *dnsServer) Add(host string, addresses []netip.Addr) {
|
||||||
|
if !d.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
host = strings.ToLower(host)
|
host = strings.ToLower(host)
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
@@ -101,7 +298,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
|
||||||
a, _, _ := net.SplitHostPort(addr)
|
a, _, _ := net.SplitHostPort(addr)
|
||||||
b, err := netip.ParseAddr(a)
|
b, err := netip.ParseAddr(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,13 +313,24 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
|||||||
return d.myVpnAddrsTable.Contains(b)
|
return d.myVpnAddrsTable.Contains(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
|
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
|
||||||
|
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
|
||||||
|
// type must be answered with NOERROR and an empty answer section (NODATA),
|
||||||
|
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
|
||||||
|
// exist at all.
|
||||||
|
anyNameExists := false
|
||||||
for _, q := range m.Question {
|
for _, q := range m.Question {
|
||||||
switch q.Qtype {
|
switch q.Qtype {
|
||||||
case dns.TypeA, dns.TypeAAAA:
|
case dns.TypeA, dns.TypeAAAA:
|
||||||
qType := dns.TypeToString[q.Qtype]
|
qType := dns.TypeToString[q.Qtype]
|
||||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
if debugEnabled {
|
||||||
ip := d.Query(q.Qtype, q.Name)
|
d.l.Debug("DNS query", "type", qType, "name", q.Name)
|
||||||
|
}
|
||||||
|
ip, nameExists := d.Query(q.Qtype, q.Name)
|
||||||
|
if nameExists {
|
||||||
|
anyNameExists = true
|
||||||
|
}
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -134,7 +342,9 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
d.l.Debugf("Query for TXT %s", q.Name)
|
if debugEnabled {
|
||||||
|
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
|
||||||
|
}
|
||||||
ip := d.QueryCert(q.Name)
|
ip := d.QueryCert(q.Name)
|
||||||
if ip != "" {
|
if ip != "" {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||||
@@ -145,12 +355,12 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.Answer) == 0 {
|
if len(m.Answer) == 0 && !anyNameExists {
|
||||||
m.Rcode = dns.RcodeNameError
|
m.Rcode = dns.RcodeNameError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.Compress = false
|
m.Compress = false
|
||||||
@@ -163,21 +373,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
|
||||||
dnsR = newDnsRecords(l, cs, hostMap)
|
|
||||||
|
|
||||||
// attach request handler func
|
|
||||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
reloadDns(l, c)
|
|
||||||
})
|
|
||||||
|
|
||||||
return func() {
|
|
||||||
startDns(l, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDnsServerAddr(c *config.C) string {
|
func getDnsServerAddr(c *config.C) string {
|
||||||
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
|
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
|
||||||
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
|
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
|
||||||
@@ -186,25 +381,3 @@ func getDnsServerAddr(c *config.C) string {
|
|||||||
}
|
}
|
||||||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func startDns(l *logrus.Logger, c *config.C) {
|
|
||||||
dnsAddr = getDnsServerAddr(c)
|
|
||||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
|
||||||
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
|
||||||
err := dnsServer.ListenAndServe()
|
|
||||||
defer dnsServer.Shutdown()
|
|
||||||
if err != nil {
|
|
||||||
l.Errorf("Failed to start server: %s\n ", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func reloadDns(l *logrus.Logger, c *config.C) {
|
|
||||||
if dnsAddr == getDnsServerAddr(c) {
|
|
||||||
l.Debug("No DNS server config change detected")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Debug("Restarting DNS server")
|
|
||||||
dnsServer.Shutdown()
|
|
||||||
go startDns(l, c)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,19 +1,43 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type stubDNSWriter struct{}
|
||||||
|
|
||||||
|
func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||||
|
func (stubDNSWriter) RemoteAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353}
|
||||||
|
}
|
||||||
|
func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||||
|
func (stubDNSWriter) Close() error { return nil }
|
||||||
|
func (stubDNSWriter) TsigStatus() error { return nil }
|
||||||
|
func (stubDNSWriter) TsigTimersOnly(bool) {}
|
||||||
|
func (stubDNSWriter) Hijack() {}
|
||||||
|
|
||||||
func TestParsequery(t *testing.T) {
|
func TestParsequery(t *testing.T) {
|
||||||
l := logrus.New()
|
l := slog.New(slog.DiscardHandler)
|
||||||
hostMap := &HostMap{}
|
hostMap := &HostMap{}
|
||||||
ds := newDnsRecords(l, &CertState{}, hostMap)
|
ds := &dnsServer{
|
||||||
|
l: l,
|
||||||
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
|
hostMap: hostMap,
|
||||||
|
}
|
||||||
|
ds.enabled.Store(true)
|
||||||
addrs := []netip.Addr{
|
addrs := []netip.Addr{
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
netip.MustParseAddr("1.2.3.5"),
|
netip.MustParseAddr("1.2.3.5"),
|
||||||
@@ -21,18 +45,56 @@ func TestParsequery(t *testing.T) {
|
|||||||
netip.MustParseAddr("fd01::25"),
|
netip.MustParseAddr("fd01::25"),
|
||||||
}
|
}
|
||||||
ds.Add("test.com.com", addrs)
|
ds.Add("test.com.com", addrs)
|
||||||
|
ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")})
|
||||||
|
ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")})
|
||||||
|
|
||||||
m := &dns.Msg{}
|
m := &dns.Msg{}
|
||||||
m.SetQuestion("test.com.com", dns.TypeA)
|
m.SetQuestion("test.com.com", dns.TypeA)
|
||||||
ds.parseQuery(m, nil)
|
ds.parseQuery(m, nil)
|
||||||
assert.NotNil(t, m.Answer)
|
assert.NotNil(t, m.Answer)
|
||||||
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
m = &dns.Msg{}
|
m = &dns.Msg{}
|
||||||
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
||||||
ds.parseQuery(m, nil)
|
ds.parseQuery(m, nil)
|
||||||
assert.NotNil(t, m.Answer)
|
assert.NotNil(t, m.Answer)
|
||||||
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
|
// A known name with no record of the requested type should return NODATA
|
||||||
|
// (NOERROR with empty answer), not NXDOMAIN.
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.SetQuestion("v4only.com.com", dns.TypeAAAA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.SetQuestion("v6only.com.com", dns.TypeA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
|
// An unknown name should still return NXDOMAIN.
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.SetQuestion("unknown.com.com", dns.TypeA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
||||||
|
|
||||||
|
// short lookups should not fail
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
|
||||||
|
ds.parseQuery(m, stubDNSWriter{})
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
||||||
|
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
|
||||||
|
ds.parseQuery(m, stubDNSWriter{})
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getDnsServerAddr(t *testing.T) {
|
func Test_getDnsServerAddr(t *testing.T) {
|
||||||
@@ -71,3 +133,208 @@ func Test_getDnsServerAddr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
||||||
|
t.Helper()
|
||||||
|
sl := slog.New(slog.DiscardHandler)
|
||||||
|
ds := &dnsServer{
|
||||||
|
l: sl,
|
||||||
|
ctx: context.Background(),
|
||||||
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
|
hostMap: &HostMap{},
|
||||||
|
}
|
||||||
|
ds.mux = dns.NewServeMux()
|
||||||
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||||
|
return ds, config.NewC(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
|
||||||
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
|
"am_lighthouse": amLighthouse,
|
||||||
|
"serve_dns": serveDns,
|
||||||
|
"dns": map[string]any{
|
||||||
|
"host": host,
|
||||||
|
"port": port,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_disabled(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, false)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_enabled(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
assert.True(t, ds.enabled.Load())
|
||||||
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
||||||
|
// initial never starts a runner; that's Control.Start's job
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", false, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
// Wants DNS but isn't a lighthouse: gated off, no runner.
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
// No server running yet, no addr change. Reload should not spawn anything.
|
||||||
|
require.NoError(t, ds.reload(c, false))
|
||||||
|
assert.True(t, ds.enabled.Load())
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_StartStop_lifecycle(t *testing.T) {
|
||||||
|
// Bind to a real (random) UDP port so we exercise the actual
|
||||||
|
// ListenAndServe + Shutdown plumbing including the started-chan race fix.
|
||||||
|
port := freeUDPPort(t)
|
||||||
|
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitFor(t, func() bool {
|
||||||
|
ds.serverMu.Lock()
|
||||||
|
started := ds.started
|
||||||
|
ds.serverMu.Unlock()
|
||||||
|
if started == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ds.Stop()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Start did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
|
||||||
|
// Stop called immediately after Start should not deadlock even if bind
|
||||||
|
// hasn't completed yet. This exercises the started-chan close-on-bind-fail
|
||||||
|
// path: by binding to an obviously bad port (privileged) we get a fast
|
||||||
|
// bind error before NotifyStartedFunc fires.
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
// Use a port that should fail to bind (negative would be invalid, use a
|
||||||
|
// host that won't resolve to ensure listenUDP fails quickly).
|
||||||
|
setDnsConfig(c, "256.256.256.256", "53", true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give Start a moment to attempt the bind and fail.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Bind failed and Start returned; Stop should be a no-op.
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Start did not return after a bad bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
stopped := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Stop()
|
||||||
|
close(stopped)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-stopped:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Stop hung after a failed bind")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
|
||||||
|
port := freeUDPPort(t)
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
startReturned := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(startReturned)
|
||||||
|
}()
|
||||||
|
waitForBind(t, ds)
|
||||||
|
|
||||||
|
// Toggle serve_dns off; reload should shut the running server down.
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, false)
|
||||||
|
require.NoError(t, ds.reload(c, false))
|
||||||
|
select {
|
||||||
|
case <-startReturned:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Start did not return after reload disabled DNS")
|
||||||
|
}
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func freeUDPPort(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
port := conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
require.NoError(t, conn.Close())
|
||||||
|
return strconv.Itoa(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForBind(t *testing.T, ds *dnsServer) {
|
||||||
|
t.Helper()
|
||||||
|
waitFor(t, func() bool {
|
||||||
|
ds.serverMu.Lock()
|
||||||
|
started := ds.started
|
||||||
|
ds.serverMu.Unlock()
|
||||||
|
if started == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitFor(t *testing.T, cond func() bool) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if cond() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatal("timed out waiting for condition")
|
||||||
|
}
|
||||||
|
|||||||
577
e2e/handshake_manager_test.go
Normal file
577
e2e/handshake_manager_test.go
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
//go:build e2e_testing
|
||||||
|
// +build e2e_testing
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeHandshakePacket creates a handshake packet with the given parameters.
|
||||||
|
func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, remoteIndex uint32, counter uint64) *udp.Packet {
|
||||||
|
data := make([]byte, 200)
|
||||||
|
header.Encode(data, header.Version, header.Handshake, subtype, remoteIndex, counter)
|
||||||
|
for i := header.Len; i < len(data); i++ {
|
||||||
|
data[i] = byte(i)
|
||||||
|
}
|
||||||
|
return &udp.Packet{To: to, From: from, Data: data}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeRetransmitDuplicate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Verify the responder correctly handles receiving the same msg1 multiple times
|
||||||
|
// (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen
|
||||||
|
// and the cached response is resent.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from me to them")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
|
||||||
|
|
||||||
|
t.Log("Grab my msg1")
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Inject msg1 into them, first time")
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
_ = theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Inject the SAME msg1 again, tests ErrAlreadySeen path")
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
resp2 := theirControl.GetFromUDP(true)
|
||||||
|
assert.NotNil(t, resp2, "should get cached response on duplicate msg1")
|
||||||
|
|
||||||
|
t.Log("Complete handshake with cached response")
|
||||||
|
myControl.InjectUDPPacket(resp2)
|
||||||
|
myControl.WaitForType(1, 0, theirControl)
|
||||||
|
|
||||||
|
t.Log("Drain cached packet and verify tunnel works")
|
||||||
|
cachedPacket := theirControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Verify only one tunnel exists on each side")
|
||||||
|
assert.Len(t, myControl.ListHostmapHosts(false), 1)
|
||||||
|
assert.Len(t, theirControl.ListHostmapHosts(false), 1)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeTruncatedPacketRecovery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Verify that a truncated handshake packet is ignored and the real
|
||||||
|
// packet can still complete the handshake.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
|
||||||
|
|
||||||
|
t.Log("Get msg1 and deliver to responder")
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
|
||||||
|
t.Log("Get the real response")
|
||||||
|
realResp := theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Truncate the response and inject, should be ignored")
|
||||||
|
truncResp := realResp.Copy()
|
||||||
|
truncResp.Data = truncResp.Data[:header.Len]
|
||||||
|
myControl.InjectUDPPacket(truncResp)
|
||||||
|
|
||||||
|
t.Log("Verify pending handshake survived the truncated packet")
|
||||||
|
assert.NotEmpty(t, myControl.ListHostmapHosts(true), "pending handshake should still exist")
|
||||||
|
|
||||||
|
t.Log("Inject real response, should complete handshake")
|
||||||
|
myControl.InjectUDPPacket(realResp)
|
||||||
|
myControl.WaitForType(1, 0, theirControl)
|
||||||
|
|
||||||
|
t.Log("Drain and verify tunnel")
|
||||||
|
cachedPacket := theirControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeOrphanedMsg2Dropped(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// A msg2 arriving with no matching pending index should be silently dropped
|
||||||
|
// with no response sent and no state changes.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Complete a normal handshake")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
|
||||||
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Record hostmap state")
|
||||||
|
myIndexes := len(myControl.ListHostmapIndexes(false))
|
||||||
|
|
||||||
|
t.Log("Inject a fake msg2 with unknown RemoteIndex")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0xDEADBEEF, 2))
|
||||||
|
|
||||||
|
t.Log("Verify no new indexes created")
|
||||||
|
assert.Equal(t, myIndexes, len(myControl.ListHostmapIndexes(false)))
|
||||||
|
|
||||||
|
t.Log("Verify no UDP response was sent")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false), "should not send a response to orphaned msg2")
|
||||||
|
|
||||||
|
t.Log("Verify existing tunnel still works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeUnknownMessageCounter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// A handshake packet with an unexpected message counter should be silently
|
||||||
|
// dropped with no side effects and no UDP response.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Inject handshake with MessageCounter=3")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 3))
|
||||||
|
|
||||||
|
t.Log("Inject handshake with MessageCounter=99")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 99))
|
||||||
|
|
||||||
|
t.Log("Verify no tunnels or pending handshakes")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
t.Log("Verify no UDP response was sent")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeUnknownSubtype(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// A handshake packet with an unknown subtype should be silently dropped.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Inject handshake with unknown subtype 99")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.MessageSubType(99), 0, 1))
|
||||||
|
|
||||||
|
t.Log("Verify no tunnels or pending handshakes")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
t.Log("Verify no UDP response was sent")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeLateResponse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// After a handshake times out, a late response should be silently ignored
|
||||||
|
// with no new tunnels created.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
|
||||||
|
"handshakes": m{
|
||||||
|
"try_interval": "200ms",
|
||||||
|
"retries": 2,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from me")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
|
||||||
|
|
||||||
|
t.Log("Grab msg1 but don't deliver")
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Wait for handshake to time out")
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
myControl.GetFromUDP(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Confirm no pending handshakes remain")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
t.Log("Deliver old msg1 to them, they create a tunnel")
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
resp := theirControl.GetFromUDP(true)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
|
||||||
|
t.Log("Inject late response into me, should be ignored")
|
||||||
|
myControl.InjectUDPPacket(resp)
|
||||||
|
|
||||||
|
t.Log("No tunnel should exist on my side")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeSelfConnectionRejected(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Verify that a node rejects a handshake containing its own VPN IP in the
|
||||||
|
// peer cert. We do this by sending the initiator's own msg1 back to itself.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
|
||||||
|
// Need a lighthouse entry to trigger a handshake
|
||||||
|
myControl.InjectLightHouseAddr(netip.MustParseAddr("10.128.0.2"), netip.MustParseAddrPort("10.0.0.2:4242"))
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from me")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Drain any handshake retransmits before injecting")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
for myControl.GetFromUDP(false) != nil {
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Feed my own msg1 back to me as if it came from someone else")
|
||||||
|
selfMsg := msg1.Copy()
|
||||||
|
selfMsg.From = netip.MustParseAddrPort("10.0.0.99:4242")
|
||||||
|
selfMsg.To = myUdpAddr
|
||||||
|
myControl.InjectUDPPacket(selfMsg)
|
||||||
|
|
||||||
|
t.Log("Verify no response was sent (self-connection rejected)")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// Drain any further retransmits from the original handshake, then check
|
||||||
|
// that none of them are a handshake response (MessageCounter=2)
|
||||||
|
h := &header.H{}
|
||||||
|
for {
|
||||||
|
p := myControl.GetFromUDP(false)
|
||||||
|
if p == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_ = h.Parse(p.Data)
|
||||||
|
assert.NotEqual(t, uint64(2), h.MessageCounter,
|
||||||
|
"should not send a stage 2 response to self-connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Verify no tunnel to myself was created")
|
||||||
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeMessageCounter0Dropped(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// MessageCounter=0 is not a valid handshake message and should be dropped.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
_, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
|
||||||
|
t.Log("Inject handshake with MessageCounter=0")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 0))
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeRemoteAllowList(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Verify that a handshake from a blocked underlay IP is dropped with no
|
||||||
|
// response and no state changes. Then verify the same packet from an
|
||||||
|
// allowed IP succeeds.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
|
||||||
|
"lighthouse": m{
|
||||||
|
"remote_allow_list": m{
|
||||||
|
"10.0.0.0/8": true,
|
||||||
|
"0.0.0.0/0": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from them")
|
||||||
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")))
|
||||||
|
msg1 := theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Rewrite the source to a blocked IP and inject")
|
||||||
|
blockedMsg := msg1.Copy()
|
||||||
|
blockedMsg.From = netip.MustParseAddrPort("192.168.1.1:4242")
|
||||||
|
myControl.InjectUDPPacket(blockedMsg)
|
||||||
|
|
||||||
|
t.Log("Verify no tunnel, no pending, no response from blocked source")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false), "should not respond to blocked source")
|
||||||
|
|
||||||
|
t.Log("Now inject the real packet from the allowed source")
|
||||||
|
myControl.InjectUDPPacket(msg1)
|
||||||
|
|
||||||
|
t.Log("Verify handshake completes from allowed source")
|
||||||
|
resp := myControl.GetFromUDP(true)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
theirControl.InjectUDPPacket(resp)
|
||||||
|
theirControl.WaitForType(1, 0, myControl)
|
||||||
|
|
||||||
|
t.Log("Drain cached packet and verify tunnel works")
|
||||||
|
cachedPacket := myControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi"), cachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel
|
||||||
|
// remains functional and hostmap index count is stable.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Complete a normal handshake via the router")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")))
|
||||||
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Record hostmap state")
|
||||||
|
theirIndexes := len(theirControl.ListHostmapIndexes(false))
|
||||||
|
hi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, hi)
|
||||||
|
originalRemote := hi.CurrentRemote
|
||||||
|
|
||||||
|
t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")))
|
||||||
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
|
||||||
|
t.Log("Verify tunnel still works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Verify remote is still valid and index count is stable")
|
||||||
|
hi2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, hi2)
|
||||||
|
assert.Equal(t, originalRemote, hi2.CurrentRemote)
|
||||||
|
assert.Equal(t, theirIndexes, len(theirControl.ListHostmapIndexes(false)),
|
||||||
|
"no extra indexes should be created from ErrAlreadySeen")
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeWrongResponderPacketStore(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Verify that when the wrong host responds, the cached packets are
|
||||||
|
// transferred to the new handshake, the evil tunnel is closed, evil's
|
||||||
|
// address is blocked, and the correct tunnel is eventually established.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
|
||||||
|
evilControl, evilVpnIpNet, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl, evilControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
evilControl.Start()
|
||||||
|
|
||||||
|
t.Log("Send multiple packets to them (cached during handshake)")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")))
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")))
|
||||||
|
|
||||||
|
t.Log("Route until evil tunnel is closed")
|
||||||
|
h := &header.H{}
|
||||||
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
|
if err := h.Parse(p.Data); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if h.Type == header.CloseTunnel && p.To == evilUdpAddr {
|
||||||
|
return router.RouteAndExit
|
||||||
|
}
|
||||||
|
return router.KeepRouting
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Log("Verify evil's address is blocked in the new pending handshake")
|
||||||
|
pendingHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
|
||||||
|
if pendingHI != nil {
|
||||||
|
assert.NotContains(t, pendingHI.RemoteAddrs, evilUdpAddr,
|
||||||
|
"evil's address should be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Inject correct lighthouse addr for them")
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
t.Log("Route until cached packets arrive at the real them")
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assert.NotNil(t, p, "a cached packet should be delivered to the correct host")
|
||||||
|
|
||||||
|
t.Log("Verify the correct host has a tunnel")
|
||||||
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||||
|
|
||||||
|
t.Log("Verify no hostinfo artifacts from evil remain")
|
||||||
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), true),
|
||||||
|
"no pending hostinfo for evil")
|
||||||
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), false),
|
||||||
|
"no main hostinfo for evil")
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
evilControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeRelayComplete(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Verify that a relay handshake completes correctly and relay state is
|
||||||
|
// properly maintained on all three nodes.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake via relay")
|
||||||
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")))
|
||||||
|
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("Verify bidirectional tunnel via relay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Verify relay state on my side shows relay-to-me")
|
||||||
|
myHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, myHI)
|
||||||
|
assert.NotEmpty(t, myHI.CurrentRelaysToMe, "should have relay-to-me for them")
|
||||||
|
|
||||||
|
t.Log("Verify relay state on their side shows relay-to-me")
|
||||||
|
theirHI := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, theirHI)
|
||||||
|
assert.NotEmpty(t, theirHI.CurrentRelaysToMe, "should have relay-to-me for me")
|
||||||
|
|
||||||
|
t.Log("Verify relay node shows through-me relays")
|
||||||
|
relayHI := relayControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, relayHI)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
relayControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: Relay V1 cert + IPv6 rejection is not tested here because
|
||||||
|
// BuildTunUDPPacket from a V4 node to a V6 address panics in the test
|
||||||
|
// framework. The check is in handshake_manager.go handleOutbound relay
|
||||||
|
// logic (lines ~304-313): if the relay host has a V1 cert and either
|
||||||
|
// address is IPv6, the relay is skipped.
|
||||||
|
|
||||||
|
// NOTE: Relay reestablishment (Disestablished state transition) is covered
|
||||||
|
// by the existing TestReestablishRelays in handshakes_test.go.
|
||||||
@@ -11,12 +11,12 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -40,11 +40,22 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
r.CancelFlowLogs()
|
r.CancelFlowLogs()
|
||||||
|
|
||||||
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
// Pre-build the IP packet bytes once so the bench measures the data plane,
|
||||||
|
// not gopacket SerializeLayers overhead.
|
||||||
|
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
|
// EnableFanIn switches the router to a 0-alloc routing path. Required
|
||||||
|
// for hot-path benchmarks; would conflict with GetFromUDP-using tests.
|
||||||
|
r.EnableFanIn()
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(prebuilt)
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
// Release the TUN-side bytes back to the harness freelist; the bench
|
||||||
|
// just confirms a packet arrived, the contents aren't inspected.
|
||||||
|
overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
|
||||||
}
|
}
|
||||||
|
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
@@ -72,11 +83,15 @@ func BenchmarkHotPathRelay(b *testing.B) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
|
prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
r.EnableFanIn()
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(prebuilt)
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl))
|
||||||
}
|
}
|
||||||
|
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
@@ -85,6 +100,7 @@ func BenchmarkHotPathRelay(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGoodHandshake(t *testing.T) {
|
func TestGoodHandshake(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
@@ -97,7 +113,7 @@ func TestGoodHandshake(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||||
@@ -135,6 +151,7 @@ func TestGoodHandshake(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGoodHandshakeNoOverlap(t *testing.T) {
|
func TestGoodHandshakeNoOverlap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
|
||||||
@@ -170,6 +187,7 @@ func TestGoodHandshakeNoOverlap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWrongResponderHandshake(t *testing.T) {
|
func TestWrongResponderHandshake(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
|
||||||
@@ -189,7 +207,7 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||||||
evilControl.Start()
|
evilControl.Start()
|
||||||
|
|
||||||
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
@@ -246,6 +264,7 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
|
||||||
@@ -270,7 +289,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
|||||||
evilControl.Start()
|
evilControl.Start()
|
||||||
|
|
||||||
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
@@ -328,6 +347,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStage1Race(t *testing.T) {
|
func TestStage1Race(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
|
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
|
||||||
// But will eventually collapse down to a single tunnel
|
// But will eventually collapse down to a single tunnel
|
||||||
|
|
||||||
@@ -348,8 +368,8 @@ func TestStage1Race(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake to start on both me and them")
|
t.Log("Trigger a handshake to start on both me and them")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
|
||||||
|
|
||||||
t.Log("Get both stage 1 handshake packets")
|
t.Log("Get both stage 1 handshake packets")
|
||||||
myHsForThem := myControl.GetFromUDP(true)
|
myHsForThem := myControl.GetFromUDP(true)
|
||||||
@@ -408,6 +428,7 @@ func TestStage1Race(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUncleanShutdownRaceLoser(t *testing.T) {
|
func TestUncleanShutdownRaceLoser(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
@@ -425,7 +446,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Trigger a handshake from me to them")
|
r.Log("Trigger a handshake from me to them")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
@@ -436,7 +457,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||||||
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
||||||
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
||||||
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")))
|
||||||
p = r.RouteForAllUntilTxTun(theirControl)
|
p = r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
@@ -457,6 +478,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUncleanShutdownRaceWinner(t *testing.T) {
|
func TestUncleanShutdownRaceWinner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
@@ -474,7 +496,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Trigger a handshake from me to them")
|
r.Log("Trigger a handshake from me to them")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
@@ -486,7 +508,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||||||
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
||||||
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
||||||
|
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")))
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
|
||||||
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
|
||||||
@@ -508,6 +530,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRelays(t *testing.T) {
|
func TestRelays(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
@@ -528,7 +551,7 @@ func TestRelays(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
@@ -537,6 +560,7 @@ func TestRelays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRelaysDontCareAboutIps(t *testing.T) {
|
func TestRelaysDontCareAboutIps(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
|
||||||
@@ -557,7 +581,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
@@ -566,6 +590,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestReestablishRelays(t *testing.T) {
|
func TestReestablishRelays(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
@@ -586,14 +611,14 @@ func TestReestablishRelays(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
t.Log("Ensure packet traversal from them to me via the relay")
|
t.Log("Ensure packet traversal from them to me via the relay")
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
|
||||||
|
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
@@ -608,7 +633,7 @@ func TestReestablishRelays(t *testing.T) {
|
|||||||
for curIndexes >= start {
|
for curIndexes >= start {
|
||||||
curIndexes = len(myControl.GetHostmap().Indexes)
|
curIndexes = len(myControl.GetHostmap().Indexes)
|
||||||
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
|
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")))
|
||||||
|
|
||||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
return router.RouteAndExit
|
return router.RouteAndExit
|
||||||
@@ -625,7 +650,7 @@ func TestReestablishRelays(t *testing.T) {
|
|||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p = r.RouteForAllUntilTxTun(theirControl)
|
p = r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
@@ -660,7 +685,7 @@ func TestReestablishRelays(t *testing.T) {
|
|||||||
t.Log("Assert the tunnel works the other way, too")
|
t.Log("Assert the tunnel works the other way, too")
|
||||||
for {
|
for {
|
||||||
t.Log("RouteForAllUntilTxTun")
|
t.Log("RouteForAllUntilTxTun")
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
|
||||||
|
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
@@ -697,6 +722,7 @@ func TestReestablishRelays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStage1RaceRelays(t *testing.T) {
|
func TestStage1RaceRelays(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
@@ -729,8 +755,8 @@ func TestStage1RaceRelays(t *testing.T) {
|
|||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
|
||||||
|
|
||||||
r.Log("Wait for a packet from them to me")
|
r.Log("Wait for a packet from them to me")
|
||||||
p := r.RouteForAllUntilTxTun(myControl)
|
p := r.RouteForAllUntilTxTun(myControl)
|
||||||
@@ -744,12 +770,12 @@ func TestStage1RaceRelays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStage1RaceRelays2(t *testing.T) {
|
func TestStage1RaceRelays2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
@@ -771,49 +797,41 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Get a tunnel between me and relay")
|
r.Log("Get a tunnel between me and relay")
|
||||||
l.Info("Get a tunnel between me and relay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Get a tunnel between them and relay")
|
r.Log("Get a tunnel between them and relay")
|
||||||
l.Info("Get a tunnel between them and relay")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||||
l.Info("Trigger a handshake from both them and me via relay to them and me")
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
|
||||||
|
|
||||||
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
||||||
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
||||||
|
|
||||||
r.Log("Wait for a packet from them to me")
|
r.Log("Wait for a packet from them to me; myControl")
|
||||||
l.Info("Wait for a packet from them to me; myControl")
|
|
||||||
r.RouteForAllUntilTxTun(myControl)
|
r.RouteForAllUntilTxTun(myControl)
|
||||||
l.Info("Wait for a packet from them to me; theirControl")
|
r.Log("Wait for a packet from them to me; theirControl")
|
||||||
r.RouteForAllUntilTxTun(theirControl)
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
t.Log("Wait until we remove extra tunnels")
|
t.Log("Wait until we remove extra tunnels")
|
||||||
l.Info("Wait until we remove extra tunnels")
|
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||||
l.WithFields(
|
len(myControl.GetHostmap().Indexes),
|
||||||
logrus.Fields{
|
len(theirControl.GetHostmap().Indexes),
|
||||||
"myControl": len(myControl.GetHostmap().Indexes),
|
len(relayControl.GetHostmap().Indexes),
|
||||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
)
|
||||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
|
||||||
}).Info("Waiting for hostinfos to be removed...")
|
|
||||||
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||||
retries := 60
|
retries := 60
|
||||||
for hostInfos > 6 && retries > 0 {
|
for hostInfos > 6 && retries > 0 {
|
||||||
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||||
l.WithFields(
|
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||||
logrus.Fields{
|
len(myControl.GetHostmap().Indexes),
|
||||||
"myControl": len(myControl.GetHostmap().Indexes),
|
len(theirControl.GetHostmap().Indexes),
|
||||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
len(relayControl.GetHostmap().Indexes),
|
||||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
)
|
||||||
}).Info("Waiting for hostinfos to be removed...")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
t.Log("Connection manager hasn't ticked yet")
|
t.Log("Connection manager hasn't ticked yet")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -821,7 +839,6 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
@@ -830,6 +847,7 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRehandshakingRelays(t *testing.T) {
|
func TestRehandshakingRelays(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
@@ -850,7 +868,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
@@ -933,6 +951,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRehandshakingRelaysPrimary(t *testing.T) {
|
func TestRehandshakingRelaysPrimary(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
|
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
|
||||||
@@ -954,7 +973,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
@@ -1037,6 +1056,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRehandshaking(t *testing.T) {
|
func TestRehandshaking(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
|
||||||
@@ -1132,6 +1152,7 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRehandshakingLoser(t *testing.T) {
|
func TestRehandshakingLoser(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
|
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
|
||||||
// Should be the one with the new certificate
|
// Should be the one with the new certificate
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
@@ -1230,6 +1251,7 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRaceRegression(t *testing.T) {
|
func TestRaceRegression(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// This test forces stage 1, stage 2, stage 1 to be received by me from them
|
// This test forces stage 1, stage 2, stage 1 to be received by me from them
|
||||||
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
|
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
|
||||||
// caused a cross-linked hostinfo
|
// caused a cross-linked hostinfo
|
||||||
@@ -1253,8 +1275,8 @@ func TestRaceRegression(t *testing.T) {
|
|||||||
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
|
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
|
||||||
|
|
||||||
t.Log("Start both handshakes")
|
t.Log("Start both handshakes")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")))
|
||||||
|
|
||||||
t.Log("Get both stage 1")
|
t.Log("Get both stage 1")
|
||||||
myStage1ForThem := myControl.GetFromUDP(true)
|
myStage1ForThem := myControl.GetFromUDP(true)
|
||||||
@@ -1290,6 +1312,7 @@ func TestRaceRegression(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
||||||
|
|
||||||
@@ -1330,6 +1353,7 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
||||||
|
|
||||||
@@ -1369,7 +1393,84 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLighthouseUpdateOnReload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
// Create the lighthouse
|
||||||
|
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh", "10.128.0.1/24", m{"lighthouse": m{"am_lighthouse": true}})
|
||||||
|
|
||||||
|
// Create a client with NO lighthouse configured and a long update interval.
|
||||||
|
// The initial SendUpdate at startup will be a no-op since no lighthouses are known.
|
||||||
|
myControl, myVpnIpNet, _, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.2/24", m{
|
||||||
|
"lighthouse": m{
|
||||||
|
"interval": 600,
|
||||||
|
"local_allow_list": m{
|
||||||
|
"10.0.0.0/24": true,
|
||||||
|
"::/0": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
r := router.NewR(t, lhControl, myControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
lhControl.Start()
|
||||||
|
myControl.Start()
|
||||||
|
|
||||||
|
// Drain any startup packets (there should be none meaningful)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
// Verify lighthouse has no knowledge of the client
|
||||||
|
assert.Nil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
|
||||||
|
|
||||||
|
// Build a new config that adds the lighthouse
|
||||||
|
newSettings := make(m)
|
||||||
|
for k, v := range myConfig.Settings {
|
||||||
|
newSettings[k] = v
|
||||||
|
}
|
||||||
|
newSettings["static_host_map"] = m{
|
||||||
|
lhVpnIpNet[0].Addr().String(): []any{lhUdpAddr.String()},
|
||||||
|
}
|
||||||
|
newSettings["lighthouse"] = m{
|
||||||
|
"hosts": []any{lhVpnIpNet[0].Addr().String()},
|
||||||
|
"interval": 600,
|
||||||
|
"local_allow_list": m{
|
||||||
|
"10.0.0.0/24": true,
|
||||||
|
"::/0": false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg, err := yaml.Marshal(newSettings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reload the config. The lighthouse.hosts change triggers TriggerUpdate,
|
||||||
|
// which wakes the update worker. It calls SendUpdate, initiating a
|
||||||
|
// handshake to the new lighthouse and caching the HostUpdateNotification.
|
||||||
|
require.NoError(t, myConfig.ReloadConfigString(string(newCfg)))
|
||||||
|
|
||||||
|
// Route until the lighthouse receives the HostUpdateNotification.
|
||||||
|
// This covers: handshake stage 1, stage 2, then the cached update.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.RouteForAllUntilAfterMsgTypeTo(lhControl, header.LightHouse, 0)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for lighthouse update after config reload")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify lighthouse now has the client's addresses
|
||||||
|
assert.NotNil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", lhControl, myControl)
|
||||||
|
lhControl.Stop()
|
||||||
|
myControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
unsafePrefix := "192.168.6.0/24"
|
unsafePrefix := "192.168.6.0/24"
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
|
||||||
@@ -1391,7 +1492,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
||||||
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||||
@@ -1419,7 +1520,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
|||||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
|
||||||
|
|
||||||
//reply
|
//reply
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")))
|
||||||
//wait for reply
|
//wait for reply
|
||||||
theirControl.WaitForType(1, 0, myControl)
|
theirControl.WaitForType(1, 0, myControl)
|
||||||
theirCachedPacket := myControl.GetFromTun(true)
|
theirCachedPacket := myControl.GetFromTun(true)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,15 +11,18 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
|||||||
"port": udpAddr.Port(),
|
"port": udpAddr.Port(),
|
||||||
},
|
},
|
||||||
"logging": m{
|
"logging": m{
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
|
"level": testLogLevelName(),
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
},
|
||||||
"timers": m{
|
"timers": m{
|
||||||
"pending_deletion_interval": 2,
|
"pending_deletion_interval": 2,
|
||||||
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
|
|||||||
"port": udpAddr.Port(),
|
"port": udpAddr.Port(),
|
||||||
},
|
},
|
||||||
"logging": m{
|
"logging": m{
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
"level": testLogLevelName(),
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
},
|
||||||
"timers": m{
|
"timers": m{
|
||||||
"pending_deletion_interval": 2,
|
"pending_deletion_interval": 2,
|
||||||
@@ -294,12 +294,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
|
|||||||
|
|
||||||
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
||||||
// Send a packet from them to me
|
// Send a packet from them to me
|
||||||
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
controlB.InjectTunPacket(BuildTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")))
|
||||||
bPacket := r.RouteForAllUntilTxTun(controlA)
|
bPacket := r.RouteForAllUntilTxTun(controlA)
|
||||||
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
|
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
|
||||||
|
|
||||||
// And once more from me to them
|
// And once more from me to them
|
||||||
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
|
controlA.InjectTunPacket(BuildTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")))
|
||||||
aPacket := r.RouteForAllUntilTxTun(controlB)
|
aPacket := r.RouteForAllUntilTxTun(controlB)
|
||||||
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
||||||
}
|
}
|
||||||
@@ -379,24 +379,87 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestLogger() *logrus.Logger {
|
func NewTestLogger() *slog.Logger {
|
||||||
l := logrus.New()
|
|
||||||
|
|
||||||
v := os.Getenv("TEST_LOGS")
|
v := os.Getenv("TEST_LOGS")
|
||||||
if v == "" {
|
if v == "" {
|
||||||
l.SetOutput(io.Discard)
|
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
l.SetLevel(logrus.PanicLevel)
|
|
||||||
return l
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
level := slog.LevelInfo
|
||||||
switch v {
|
switch v {
|
||||||
case "2":
|
case "2":
|
||||||
l.SetLevel(logrus.DebugLevel)
|
level = slog.LevelDebug
|
||||||
case "3":
|
case "3":
|
||||||
l.SetLevel(logrus.TraceLevel)
|
level = logging.LevelTrace
|
||||||
default:
|
}
|
||||||
l.SetLevel(logrus.InfoLevel)
|
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// testLogLevelName returns the level name string accepted by logging.ApplyConfig
|
||||||
|
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
|
||||||
|
func testLogLevelName() string {
|
||||||
|
switch os.Getenv("TEST_LOGS") {
|
||||||
|
case "2":
|
||||||
|
return "debug"
|
||||||
|
case "3":
|
||||||
|
return "trace"
|
||||||
|
case "":
|
||||||
|
return "info"
|
||||||
|
}
|
||||||
|
return "info"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildTunUDPPacket assembles an IP+UDP packet suitable for Control.InjectTunPacket.
|
||||||
|
// Using UDP here because it's a simpler protocol.
|
||||||
|
func BuildTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) []byte {
|
||||||
|
serialize := make([]gopacket.SerializableLayer, 0)
|
||||||
|
var netLayer gopacket.NetworkLayer
|
||||||
|
if toAddr.Is6() {
|
||||||
|
if !fromAddr.Is6() {
|
||||||
|
panic("Cant send ipv6 to ipv4")
|
||||||
|
}
|
||||||
|
ip := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
SrcIP: fromAddr.Unmap().AsSlice(),
|
||||||
|
DstIP: toAddr.Unmap().AsSlice(),
|
||||||
|
}
|
||||||
|
serialize = append(serialize, ip)
|
||||||
|
netLayer = ip
|
||||||
|
} else {
|
||||||
|
if !fromAddr.Is4() {
|
||||||
|
panic("Cant send ipv4 to ipv6")
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
SrcIP: fromAddr.Unmap().AsSlice(),
|
||||||
|
DstIP: toAddr.Unmap().AsSlice(),
|
||||||
|
}
|
||||||
|
serialize = append(serialize, ip)
|
||||||
|
netLayer = ip
|
||||||
}
|
}
|
||||||
|
|
||||||
return l
|
udp := layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(fromPort),
|
||||||
|
DstPort: layers.UDPPort(toPort),
|
||||||
|
}
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(netLayer); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
opt := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
serialize = append(serialize, &udp, gopacket.Payload(data))
|
||||||
|
if err := gopacket.SerializeLayers(buffer, opt, serialize...); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer.Bytes()
|
||||||
}
|
}
|
||||||
|
|||||||
51
e2e/leak_test.go
Normal file
51
e2e/leak_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
//go:build e2e_testing
|
||||||
|
// +build e2e_testing
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"go.uber.org/goleak"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNoGoroutineLeaks brings up two nebula instances, completes a tunnel,
|
||||||
|
// stops both, and asserts no goroutines leak past the shutdown. goleak's
|
||||||
|
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
|
||||||
|
// before failing the assertion.
|
||||||
|
//
|
||||||
|
// IgnoreCurrent is necessary in the parallelized suite: other tests can
|
||||||
|
// leave goroutines mid-shutdown when this one runs (Stop is async, the
|
||||||
|
// wg.Wait() drain is not blocking on test return). We're checking that
|
||||||
|
// *this* test's setup tears down cleanly, not that the whole suite is
|
||||||
|
// idle at this moment. Intentionally NOT t.Parallel()'d for the same
|
||||||
|
// reason — concurrent test goroutines would always show up.
|
||||||
|
func TestNoGoroutineLeaks(t *testing.T) {
|
||||||
|
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
r.RenderFlow()
|
||||||
|
|
||||||
|
// Settle period: Stop() is non-blocking; the wg-driven goroutines need
|
||||||
|
// a moment to drain. goleak retries internally too, but a short explicit
|
||||||
|
// settle reduces flakes when the suite is busy.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,6 +25,19 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// outNatKey is the (from, to) pair used by outNat. Comparable struct, so it works as a map key without the
|
||||||
|
// allocation cost of a string-concat key.
|
||||||
|
type outNatKey struct {
|
||||||
|
from, to netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// fannedPacket pairs a UDP TX packet with its source control so the router can route it after popping from
|
||||||
|
// the fan-in channel.
|
||||||
|
type fannedPacket struct {
|
||||||
|
from *nebula.Control
|
||||||
|
pkt *udp.Packet
|
||||||
|
}
|
||||||
|
|
||||||
type R struct {
|
type R struct {
|
||||||
// Simple map of the ip:port registered on a control to the control
|
// Simple map of the ip:port registered on a control to the control
|
||||||
// Basically a router, right?
|
// Basically a router, right?
|
||||||
@@ -34,12 +48,28 @@ type R struct {
|
|||||||
|
|
||||||
// A last used map, if an inbound packet hit the inNat map then
|
// A last used map, if an inbound packet hit the inNat map then
|
||||||
// all return packets should use the same last used inbound address for the outbound sender
|
// all return packets should use the same last used inbound address for the outbound sender
|
||||||
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
|
outNat map[outNatKey]netip.AddrPort
|
||||||
outNat map[string]netip.AddrPort
|
|
||||||
|
|
||||||
// A map of vpn ip to the nebula control it belongs to
|
// A map of vpn ip to the nebula control it belongs to
|
||||||
vpnControls map[netip.Addr]*nebula.Control
|
vpnControls map[netip.Addr]*nebula.Control
|
||||||
|
|
||||||
|
// Cached select infrastructure for RouteForAllUntilTxTun.
|
||||||
|
// The controls map is immutable after NewR so the cases are good for the test lifetime.
|
||||||
|
// We only rebuild if a different receiver is asked.
|
||||||
|
selRecvCtl *nebula.Control
|
||||||
|
selCases []reflect.SelectCase
|
||||||
|
selCtls []*nebula.Control
|
||||||
|
|
||||||
|
// Optional fan-in mode for hot-path benchmarks: one forwarder goroutine per control drains UDP TX into udpFanIn,
|
||||||
|
// so RouteForAllUntilTxTun can do a fixed 2-way native select instead of paying reflect.Select per call.
|
||||||
|
// Off by default (would otherwise interleave with tests that use GetFromUDP directly on the same control).
|
||||||
|
// Enabled by EnableFanIn.
|
||||||
|
udpFanIn chan fannedPacket
|
||||||
|
stopFanIn chan struct{}
|
||||||
|
fanInWG sync.WaitGroup
|
||||||
|
fanInMu sync.Mutex
|
||||||
|
fanInOn atomic.Bool
|
||||||
|
|
||||||
ignoreFlows []ignoreFlow
|
ignoreFlows []ignoreFlow
|
||||||
flow []flowEntry
|
flow []flowEntry
|
||||||
|
|
||||||
@@ -119,7 +149,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
|||||||
controls: make(map[netip.AddrPort]*nebula.Control),
|
controls: make(map[netip.AddrPort]*nebula.Control),
|
||||||
vpnControls: make(map[netip.Addr]*nebula.Control),
|
vpnControls: make(map[netip.Addr]*nebula.Control),
|
||||||
inNat: make(map[netip.AddrPort]*nebula.Control),
|
inNat: make(map[netip.AddrPort]*nebula.Control),
|
||||||
outNat: make(map[string]netip.AddrPort),
|
outNat: make(map[outNatKey]netip.AddrPort),
|
||||||
flow: []flowEntry{},
|
flow: []flowEntry{},
|
||||||
ignoreFlows: []ignoreFlow{},
|
ignoreFlows: []ignoreFlow{},
|
||||||
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
|
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
|
||||||
@@ -153,8 +183,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-clockSource.C:
|
case <-clockSource.C:
|
||||||
|
r.Lock()
|
||||||
r.renderHostmaps("clock tick")
|
r.renderHostmaps("clock tick")
|
||||||
r.renderFlow()
|
r.renderFlow()
|
||||||
|
r.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -180,15 +212,21 @@ func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
|
|||||||
// RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening.
|
// RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening.
|
||||||
func (r *R) RenderFlow() {
|
func (r *R) RenderFlow() {
|
||||||
r.cancelRender()
|
r.cancelRender()
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
r.renderFlow()
|
r.renderFlow()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected
|
// CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected
|
||||||
func (r *R) CancelFlowLogs() {
|
func (r *R) CancelFlowLogs() {
|
||||||
r.cancelRender()
|
r.cancelRender()
|
||||||
|
r.Lock()
|
||||||
r.flow = nil
|
r.flow = nil
|
||||||
|
r.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// renderFlow writes the flow log to disk. Caller must hold r.Lock. renderFlow reads r.flow / r.additionalGraphs and
|
||||||
|
// the *packet pointers stashed inside, all of which are mutated under the same lock by routing paths.
|
||||||
func (r *R) renderFlow() {
|
func (r *R) renderFlow() {
|
||||||
if r.flow == nil {
|
if r.flow == nil {
|
||||||
return
|
return
|
||||||
@@ -434,68 +472,157 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
|
|||||||
panic("No control for udp tx " + a.String())
|
panic("No control for udp tx " + a.String())
|
||||||
}
|
}
|
||||||
fp := r.unlockedInjectFlow(sender, c, p, false)
|
fp := r.unlockedInjectFlow(sender, c, p, false)
|
||||||
c.InjectUDPPacket(p)
|
c.InjectUDPPacket(p) // copies internally; original is ours to release
|
||||||
fp.WasReceived()
|
fp.WasReceived()
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on receivers tun
|
// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on the receiver's tun.
|
||||||
// If the router doesn't have the nebula controller for that address, we panic
|
// If a control's UDP TX address can't be matched to a registered control, we panic.
|
||||||
|
//
|
||||||
|
// For allocation-sensitive callers (hot-path benchmarks, in particular relay
|
||||||
|
// benches with 3+ controls), call EnableFanIn() first.
|
||||||
func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
|
func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
|
||||||
|
if r.fanInOn.Load() {
|
||||||
|
return r.routeFanIn(receiver)
|
||||||
|
}
|
||||||
|
return r.routeReflect(receiver)
|
||||||
|
}
|
||||||
|
|
||||||
|
// routeFanIn is the alloc-free path used when EnableFanIn is in effect.
|
||||||
|
func (r *R) routeFanIn(receiver *nebula.Control) []byte {
|
||||||
|
tunTx := receiver.GetTunTxChan()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case p := <-tunTx:
|
||||||
|
r.Lock()
|
||||||
|
if r.flow != nil {
|
||||||
|
np := udp.Packet{Data: make([]byte, len(p))}
|
||||||
|
copy(np.Data, p)
|
||||||
|
r.unlockedInjectFlow(receiver, receiver, &np, true)
|
||||||
|
}
|
||||||
|
r.Unlock()
|
||||||
|
return p
|
||||||
|
case fp := <-r.udpFanIn:
|
||||||
|
r.routeUDP(fp.from, fp.pkt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// routeReflect is the default reflect.Select-based path. Pays the boxing allocation per call but doesn't interfere
|
||||||
|
// with tests that pull packets directly from controls' UDP TX channels via GetFromUDP.
|
||||||
|
func (r *R) routeReflect(receiver *nebula.Control) []byte {
|
||||||
|
sc, cm := r.selectCasesFor(receiver)
|
||||||
|
for {
|
||||||
|
x, rx, _ := reflect.Select(sc)
|
||||||
|
if x == 0 {
|
||||||
|
p := rx.Interface().([]byte)
|
||||||
|
r.Lock()
|
||||||
|
if r.flow != nil {
|
||||||
|
np := udp.Packet{Data: make([]byte, len(p))}
|
||||||
|
copy(np.Data, p)
|
||||||
|
r.unlockedInjectFlow(cm[x], cm[x], &np, true)
|
||||||
|
}
|
||||||
|
r.Unlock()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
r.routeUDP(cm[x], rx.Interface().(*udp.Packet))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableFanIn switches RouteForAllUntilTxTun to the alloc-free fan-in path.
|
||||||
|
// One forwarder goroutine per registered control drains UDP TX into a shared channel that RouteForAllUntilTxTun selects
|
||||||
|
// on alongside the receiver's TUN TX channel.
|
||||||
|
func (r *R) EnableFanIn() {
|
||||||
|
r.fanInMu.Lock()
|
||||||
|
defer r.fanInMu.Unlock()
|
||||||
|
if r.fanInOn.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.udpFanIn = make(chan fannedPacket, 32)
|
||||||
|
r.stopFanIn = make(chan struct{})
|
||||||
|
for _, c := range r.controls {
|
||||||
|
r.startFanInWorker(c)
|
||||||
|
}
|
||||||
|
r.fanInOn.Store(true)
|
||||||
|
r.t.Cleanup(r.stopFanInWorkers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// startFanInWorker spawns a goroutine that drains c's UDP TX into r.udpFanIn.
|
||||||
|
func (r *R) startFanInWorker(c *nebula.Control) {
|
||||||
|
r.fanInWG.Add(1)
|
||||||
|
udpTx := c.GetUDPTxChan()
|
||||||
|
go func() {
|
||||||
|
defer r.fanInWG.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.stopFanIn:
|
||||||
|
return
|
||||||
|
case p := <-udpTx:
|
||||||
|
select {
|
||||||
|
case <-r.stopFanIn:
|
||||||
|
p.Release()
|
||||||
|
return
|
||||||
|
case r.udpFanIn <- fannedPacket{from: c, pkt: p}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopFanInWorkers signals the fan-in goroutines to exit and waits for them.
|
||||||
|
func (r *R) stopFanInWorkers() {
|
||||||
|
r.fanInMu.Lock()
|
||||||
|
wasOn := r.fanInOn.Swap(false)
|
||||||
|
r.fanInMu.Unlock()
|
||||||
|
if !wasOn {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
close(r.stopFanIn)
|
||||||
|
r.fanInWG.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// routeUDP forwards a UDP TX packet from the named source control to the destination control derived from p.To,
|
||||||
|
// releasing the source packet after InjectUDPPacket has copied its bytes into a fresh pool slot.
|
||||||
|
func (r *R) routeUDP(from *nebula.Control, p *udp.Packet) {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
a := from.GetUDPAddr()
|
||||||
|
c := r.getControl(a, p.To, p)
|
||||||
|
if c == nil {
|
||||||
|
panic(fmt.Sprintf("No control for udp tx %s", p.To))
|
||||||
|
}
|
||||||
|
fp := r.unlockedInjectFlow(from, c, p, false)
|
||||||
|
c.InjectUDPPacket(p) // copies internally; original is ours to release
|
||||||
|
fp.WasReceived()
|
||||||
|
p.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectCasesFor returns the SelectCase array used by routeReflect: one slot for the receiver's TUN TX channel followed
|
||||||
|
// by one per control's UDP TX channel. Cached for the test lifetime, only rebuilt if the receiver changes.
|
||||||
|
func (r *R) selectCasesFor(receiver *nebula.Control) ([]reflect.SelectCase, []*nebula.Control) {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
if r.selRecvCtl == receiver && r.selCases != nil {
|
||||||
|
return r.selCases, r.selCtls
|
||||||
|
}
|
||||||
sc := make([]reflect.SelectCase, len(r.controls)+1)
|
sc := make([]reflect.SelectCase, len(r.controls)+1)
|
||||||
cm := make([]*nebula.Control, len(r.controls)+1)
|
cm := make([]*nebula.Control, len(r.controls)+1)
|
||||||
|
sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan())}
|
||||||
i := 0
|
cm[0] = receiver
|
||||||
sc[i] = reflect.SelectCase{
|
i := 1
|
||||||
Dir: reflect.SelectRecv,
|
|
||||||
Chan: reflect.ValueOf(receiver.GetTunTxChan()),
|
|
||||||
Send: reflect.Value{},
|
|
||||||
}
|
|
||||||
cm[i] = receiver
|
|
||||||
|
|
||||||
i++
|
|
||||||
for _, c := range r.controls {
|
for _, c := range r.controls {
|
||||||
sc[i] = reflect.SelectCase{
|
sc[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan())}
|
||||||
Dir: reflect.SelectRecv,
|
|
||||||
Chan: reflect.ValueOf(c.GetUDPTxChan()),
|
|
||||||
Send: reflect.Value{},
|
|
||||||
}
|
|
||||||
|
|
||||||
cm[i] = c
|
cm[i] = c
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
r.selRecvCtl = receiver
|
||||||
for {
|
r.selCases = sc
|
||||||
x, rx, _ := reflect.Select(sc)
|
r.selCtls = cm
|
||||||
r.Lock()
|
return sc, cm
|
||||||
|
|
||||||
if x == 0 {
|
|
||||||
// we are the tun tx, we can exit
|
|
||||||
p := rx.Interface().([]byte)
|
|
||||||
np := udp.Packet{Data: make([]byte, len(p))}
|
|
||||||
copy(np.Data, p)
|
|
||||||
|
|
||||||
r.unlockedInjectFlow(cm[x], cm[x], &np, true)
|
|
||||||
r.Unlock()
|
|
||||||
return p
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// we are a udp tx, route and continue
|
|
||||||
p := rx.Interface().(*udp.Packet)
|
|
||||||
a := cm[x].GetUDPAddr()
|
|
||||||
c := r.getControl(a, p.To, p)
|
|
||||||
if c == nil {
|
|
||||||
r.Unlock()
|
|
||||||
panic(fmt.Sprintf("No control for udp tx %s", p.To))
|
|
||||||
}
|
|
||||||
fp := r.unlockedInjectFlow(cm[x], c, p, false)
|
|
||||||
c.InjectUDPPacket(p)
|
|
||||||
fp.WasReceived()
|
|
||||||
}
|
|
||||||
r.Unlock()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteExitFunc will call the whatDo func with each udp packet from sender.
|
// RouteExitFunc will call the whatDo func with each udp packet from sender.
|
||||||
@@ -522,6 +649,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
|
|||||||
switch e {
|
switch e {
|
||||||
case ExitNow:
|
case ExitNow:
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
return
|
return
|
||||||
|
|
||||||
case RouteAndExit:
|
case RouteAndExit:
|
||||||
@@ -529,6 +657,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
|
|||||||
receiver.InjectUDPPacket(p)
|
receiver.InjectUDPPacket(p)
|
||||||
fp.WasReceived()
|
fp.WasReceived()
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
return
|
return
|
||||||
|
|
||||||
case KeepRouting:
|
case KeepRouting:
|
||||||
@@ -541,6 +670,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -641,6 +771,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
|
|||||||
switch e {
|
switch e {
|
||||||
case ExitNow:
|
case ExitNow:
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
return
|
return
|
||||||
|
|
||||||
case RouteAndExit:
|
case RouteAndExit:
|
||||||
@@ -648,6 +779,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
|
|||||||
receiver.InjectUDPPacket(p)
|
receiver.InjectUDPPacket(p)
|
||||||
fp.WasReceived()
|
fp.WasReceived()
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
return
|
return
|
||||||
|
|
||||||
case KeepRouting:
|
case KeepRouting:
|
||||||
@@ -659,6 +791,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
|
|||||||
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
|
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
|
||||||
}
|
}
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -702,19 +835,20 @@ func (r *R) FlushAll() {
|
|||||||
}
|
}
|
||||||
receiver.InjectUDPPacket(p)
|
receiver.InjectUDPPacket(p)
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
|
p.Release()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
|
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
|
||||||
// This is an internal router function, the caller must hold the lock
|
// This is an internal router function, the caller must hold the lock
|
||||||
func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
|
func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
|
||||||
if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
|
if newAddr, ok := r.outNat[outNatKey{from: fromAddr, to: toAddr}]; ok {
|
||||||
p.From = newAddr
|
p.From = newAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
c, ok := r.inNat[toAddr]
|
c, ok := r.inNat[toAddr]
|
||||||
if ok {
|
if ok {
|
||||||
r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
|
r.outNat[outNatKey{from: c.GetUDPAddr(), to: fromAddr}] = toAddr
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,14 @@ import (
|
|||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDropInactiveTunnels(t *testing.T) {
|
func TestDropInactiveTunnels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
// under ideal conditions
|
// under ideal conditions
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
@@ -61,6 +64,7 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCertUpgrade(t *testing.T) {
|
func TestCertUpgrade(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
// under ideal conditions
|
// under ideal conditions
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
@@ -155,6 +159,7 @@ func TestCertUpgrade(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCertDowngrade(t *testing.T) {
|
func TestCertDowngrade(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
// under ideal conditions
|
// under ideal conditions
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
@@ -253,6 +258,7 @@ func TestCertDowngrade(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCertMismatchCorrection(t *testing.T) {
|
func TestCertMismatchCorrection(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
// under ideal conditions
|
// under ideal conditions
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
@@ -320,6 +326,7 @@ func TestCertMismatchCorrection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCrossStackRelaysWork(t *testing.T) {
|
func TestCrossStackRelaysWork(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||||
@@ -348,14 +355,14 @@ func TestCrossStackRelaysWork(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
t.Log("reply?")
|
t.Log("reply?")
|
||||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")))
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
@@ -365,3 +372,107 @@ func TestCrossStackRelaysWork(t *testing.T) {
|
|||||||
//theirControl.Stop()
|
//theirControl.Stop()
|
||||||
//relayControl.Stop()
|
//relayControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseTunnelAuthenticated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
r.Log("Close the tunnel")
|
||||||
|
myControl.CloseTunnel(theirVpnIpNet[0].Addr(), false)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
myIndexes := len(myControl.GetHostmap().Indexes)
|
||||||
|
theirIndexes := len(theirControl.GetHostmap().Indexes)
|
||||||
|
if myIndexes == 0 && theirIndexes == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
|
||||||
|
if since > time.Second*6 {
|
||||||
|
t.Fatal("Tunnel should have been declared inactive after 2 seconds and before 6 seconds")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
//r.FlushAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("Happy path success, tunnels were dropped within %v", time.Since(waitStart))
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
r.Log("Assert another tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
hi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if hi == nil {
|
||||||
|
t.Fatal("There is no hostinfo for this tunnel")
|
||||||
|
}
|
||||||
|
myHi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
if myHi == nil {
|
||||||
|
t.Fatal("There is no hostinfo for my tunnel")
|
||||||
|
}
|
||||||
|
r.Log("It does")
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
hdr := header.H{
|
||||||
|
Version: 1,
|
||||||
|
Type: header.CloseTunnel,
|
||||||
|
Subtype: 0,
|
||||||
|
Reserved: 0,
|
||||||
|
RemoteIndex: hi.RemoteIndex,
|
||||||
|
MessageCounter: 5,
|
||||||
|
}
|
||||||
|
out, err := hdr.Encode(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := &udp.Packet{
|
||||||
|
To: hi.CurrentRemote,
|
||||||
|
From: myHi.CurrentRemote,
|
||||||
|
Data: out,
|
||||||
|
}
|
||||||
|
r.InjectUDPPacket(myControl, theirControl, pkt)
|
||||||
|
r.Log("Injected bogus close tunnel. Let's see!")
|
||||||
|
waitStart = time.Now()
|
||||||
|
for {
|
||||||
|
myIndexes := len(myControl.GetHostmap().Indexes)
|
||||||
|
theirIndexes := len(theirControl.GetHostmap().Indexes)
|
||||||
|
if myIndexes == 0 {
|
||||||
|
t.Fatal("myIndexes should not be 0")
|
||||||
|
}
|
||||||
|
if theirIndexes == 0 {
|
||||||
|
t.Fatal("theirIndexes should not be 0, they should have rejected this bogus packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
|
||||||
|
if since > time.Second*4 {
|
||||||
|
t.Log("The tunnel would have been gone by now")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|||||||
@@ -204,6 +204,12 @@ punchy:
|
|||||||
# Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access.
|
# Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access.
|
||||||
#trusted_cas:
|
#trusted_cas:
|
||||||
#- "ssh public key string"
|
#- "ssh public key string"
|
||||||
|
# sandbox_dir restricts file paths for profiling commands (start-cpu-profile, save-heap-profile,
|
||||||
|
# save-mutex-profile) to the specified directory. Relative paths will be resolved within this directory,
|
||||||
|
# and absolute paths outside of it will be rejected. Default is $TMP/nebula-debug.
|
||||||
|
# The directory is NOT automatically created.
|
||||||
|
# Overriding this to "" is the same as "/" and will allow overwriting any path on the host.
|
||||||
|
#sandbox_dir: /var/tmp/nebula-debug
|
||||||
|
|
||||||
# EXPERIMENTAL: relay support for networks that can't establish direct connections.
|
# EXPERIMENTAL: relay support for networks that can't establish direct connections.
|
||||||
relay:
|
relay:
|
||||||
@@ -327,24 +333,21 @@ tun:
|
|||||||
|
|
||||||
# Configure logging level
|
# Configure logging level
|
||||||
logging:
|
logging:
|
||||||
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
|
# trace, debug, info, warn, or error. Default is info and is reloadable.
|
||||||
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
|
# fatal and panic are accepted for backwards compatibility and map to error.
|
||||||
# scenarios. Debug logging is also CPU intensive and will decrease performance overall.
|
#NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
|
||||||
# Only enable debug logging while actively investigating an issue.
|
# scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
|
||||||
|
# Only enable debug or trace logging while actively investigating an issue.
|
||||||
level: info
|
level: info
|
||||||
# json or text formats currently available. Default is text
|
# json or text formats currently available. Default is text.
|
||||||
format: text
|
format: text
|
||||||
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
|
# Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
|
||||||
#disable_timestamp: true
|
#disable_timestamp: true
|
||||||
# timestamp format is specified in Go time format, see:
|
# Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
|
||||||
# https://golang.org/pkg/time/#pkg-constants
|
|
||||||
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
|
||||||
# default when `format: text`:
|
|
||||||
# when TTY attached: seconds since beginning of execution
|
|
||||||
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
|
||||||
# As an example, to log as RFC3339 with millisecond precision, set to:
|
|
||||||
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
|
|
||||||
|
|
||||||
|
# The stats section is reloadable. A HUP may change the backend, toggle stats
|
||||||
|
# on or off, switch the listen/host address, or pick up new DNS for the
|
||||||
|
# configured graphite host.
|
||||||
#stats:
|
#stats:
|
||||||
#type: graphite
|
#type: graphite
|
||||||
#prefix: nebula
|
#prefix: nebula
|
||||||
@@ -362,10 +365,12 @@ logging:
|
|||||||
# enables counter metrics for meta packets
|
# enables counter metrics for meta packets
|
||||||
# e.g.: `messages.tx.handshake`
|
# e.g.: `messages.tx.handshake`
|
||||||
# NOTE: `message.{tx,rx}.recv_error` is always emitted
|
# NOTE: `message.{tx,rx}.recv_error` is always emitted
|
||||||
|
# Not reloadable.
|
||||||
#message_metrics: false
|
#message_metrics: false
|
||||||
|
|
||||||
# enables detailed counter metrics for lighthouse packets
|
# enables detailed counter metrics for lighthouse packets
|
||||||
# e.g.: `lighthouse.rx.HostQuery`
|
# e.g.: `lighthouse.rx.HostQuery`
|
||||||
|
# Not reloadable.
|
||||||
#lighthouse_metrics: false
|
#lighthouse_metrics: false
|
||||||
|
|
||||||
# Handshake Manager Settings
|
# Handshake Manager Settings
|
||||||
@@ -423,8 +428,8 @@ firewall:
|
|||||||
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
|
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
|
||||||
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
|
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
|
||||||
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
|
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
|
||||||
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
|
|
||||||
# proto: `any`, `tcp`, `udp`, or `icmp`
|
# proto: `any`, `tcp`, `udp`, or `icmp`
|
||||||
|
# a port specification is ignored if proto is `icmp`
|
||||||
# host: `any` or a literal hostname, ie `test-host`
|
# host: `any` or a literal hostname, ie `test-host`
|
||||||
# group: `any` or a literal group name, ie `default-group`
|
# group: `any` or a literal group name, ie `default-group`
|
||||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/service"
|
"github.com/slackhq/nebula/service"
|
||||||
)
|
)
|
||||||
@@ -64,8 +64,7 @@ pki:
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := logrus.New()
|
logger := logging.NewLogger(os.Stdout)
|
||||||
logger.Out = os.Stdout
|
|
||||||
|
|
||||||
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
203
firewall.go
203
firewall.go
@@ -1,11 +1,13 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -16,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -67,7 +68,7 @@ type Firewall struct {
|
|||||||
incomingMetrics firewallMetrics
|
incomingMetrics firewallMetrics
|
||||||
outgoingMetrics firewallMetrics
|
outgoingMetrics firewallMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type firewallMetrics struct {
|
type firewallMetrics struct {
|
||||||
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
|
|||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
// The certificate provided should be the highest version loaded in memory.
|
// The certificate provided should be the highest version loaded in memory.
|
||||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||||
//TODO: error on 0 duration
|
//TODO: error on 0 duration
|
||||||
var tmin, tmax time.Duration
|
var tmin, tmax time.Duration
|
||||||
|
|
||||||
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||||
certificate := cs.getCertificate(cert.Version2)
|
certificate := cs.getCertificate(cert.Version2)
|
||||||
if certificate == nil {
|
if certificate == nil {
|
||||||
certificate = cs.getCertificate(cert.Version1)
|
certificate = cs.getCertificate(cert.Version1)
|
||||||
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
case "drop":
|
case "drop":
|
||||||
fw.InSendReject = false
|
fw.InSendReject = false
|
||||||
default:
|
default:
|
||||||
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`")
|
l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
|
||||||
fw.InSendReject = false
|
fw.InSendReject = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
case "drop":
|
case "drop":
|
||||||
fw.OutSendReject = false
|
fw.OutSendReject = false
|
||||||
default:
|
default:
|
||||||
l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`")
|
l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
|
||||||
fw.OutSendReject = false
|
fw.OutSendReject = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -249,20 +250,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
|
|
||||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
||||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
|
||||||
ruleString := fmt.Sprintf(
|
|
||||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
|
||||||
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
|
||||||
)
|
|
||||||
f.rules += ruleString + "\n"
|
|
||||||
|
|
||||||
direction := "incoming"
|
|
||||||
if !incoming {
|
|
||||||
direction = "outgoing"
|
|
||||||
}
|
|
||||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
|
||||||
Info("Firewall rule added")
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ft *FirewallTable
|
ft *FirewallTable
|
||||||
fp firewallPort
|
fp firewallPort
|
||||||
@@ -280,6 +267,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
case firewall.ProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
fp = ft.UDP
|
fp = ft.UDP
|
||||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||||
|
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
|
||||||
|
if startPort != firewall.PortAny {
|
||||||
|
f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
|
||||||
|
}
|
||||||
|
startPort = firewall.PortAny
|
||||||
|
endPort = firewall.PortAny
|
||||||
fp = ft.ICMP
|
fp = ft.ICMP
|
||||||
case firewall.ProtoAny:
|
case firewall.ProtoAny:
|
||||||
fp = ft.AnyProto
|
fp = ft.AnyProto
|
||||||
@@ -287,6 +280,21 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
return fmt.Errorf("unknown protocol %v", proto)
|
return fmt.Errorf("unknown protocol %v", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||||
|
ruleString := fmt.Sprintf(
|
||||||
|
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||||
|
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
||||||
|
)
|
||||||
|
f.rules += ruleString + "\n"
|
||||||
|
|
||||||
|
direction := "incoming"
|
||||||
|
if !incoming {
|
||||||
|
direction = "outgoing"
|
||||||
|
}
|
||||||
|
f.l.Info("Firewall rule added",
|
||||||
|
"firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
|
||||||
|
)
|
||||||
|
|
||||||
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
|
|||||||
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||||
var table string
|
var table string
|
||||||
if inbound {
|
if inbound {
|
||||||
table = "firewall.inbound"
|
table = "firewall.inbound"
|
||||||
@@ -349,24 +357,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
sPort = r.Port
|
sPort = r.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
startPort, endPort, err := parsePort(sPort)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var proto uint8
|
var proto uint8
|
||||||
|
var startPort, endPort int32
|
||||||
switch r.Proto {
|
switch r.Proto {
|
||||||
case "any":
|
case "any":
|
||||||
proto = firewall.ProtoAny
|
proto = firewall.ProtoAny
|
||||||
|
startPort, endPort, err = parsePort(sPort)
|
||||||
case "tcp":
|
case "tcp":
|
||||||
proto = firewall.ProtoTCP
|
proto = firewall.ProtoTCP
|
||||||
|
startPort, endPort, err = parsePort(sPort)
|
||||||
case "udp":
|
case "udp":
|
||||||
proto = firewall.ProtoUDP
|
proto = firewall.ProtoUDP
|
||||||
|
startPort, endPort, err = parsePort(sPort)
|
||||||
case "icmp":
|
case "icmp":
|
||||||
proto = firewall.ProtoICMP
|
proto = firewall.ProtoICMP
|
||||||
|
startPort = firewall.PortAny
|
||||||
|
endPort = firewall.PortAny
|
||||||
|
if sPort != "" {
|
||||||
|
l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
|
||||||
|
}
|
||||||
|
|
||||||
if r.Cidr != "" && r.Cidr != "any" {
|
if r.Cidr != "" && r.Cidr != "any" {
|
||||||
_, err = netip.ParsePrefix(r.Cidr)
|
_, err = netip.ParsePrefix(r.Cidr)
|
||||||
@@ -383,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
}
|
}
|
||||||
|
|
||||||
if warning := r.sanity(); warning != nil {
|
if warning := r.sanity(); warning != nil {
|
||||||
l.Warnf("%s rule #%v; %s", table, i, warning)
|
l.Warn("firewall rule sanity check",
|
||||||
|
"table", table,
|
||||||
|
"rule", i,
|
||||||
|
"warning", warning,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
||||||
@@ -467,7 +486,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
|
// Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new
|
||||||
// firewall object is created
|
// firewall object is created
|
||||||
func (f *Firewall) Destroy() {
|
func (f *Firewall) Destroy() {
|
||||||
//TODO: clean references if/when needed
|
//TODO: clean references if/when needed
|
||||||
@@ -515,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
|
|
||||||
// We now know which firewall table to check against
|
// We now know which firewall table to check against
|
||||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
h.logger(f.l).
|
h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
|
||||||
WithField("fwPacket", fp).
|
"fwPacket", fp,
|
||||||
WithField("incoming", c.incoming).
|
"incoming", c.incoming,
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
"rulesVersion", f.rulesVersion,
|
||||||
WithField("oldRulesVersion", c.rulesVersion).
|
"oldRulesVersion", c.rulesVersion,
|
||||||
Debugln("dropping old conntrack entry, does not match new ruleset")
|
)
|
||||||
}
|
}
|
||||||
delete(conntrack.Conns, fp)
|
delete(conntrack.Conns, fp)
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
h.logger(f.l).
|
h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
|
||||||
WithField("fwPacket", fp).
|
"fwPacket", fp,
|
||||||
WithField("incoming", c.incoming).
|
"incoming", c.incoming,
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
"rulesVersion", f.rulesVersion,
|
||||||
WithField("oldRulesVersion", c.rulesVersion).
|
"oldRulesVersion", c.rulesVersion,
|
||||||
Debugln("keeping old conntrack entry, does match new ruleset")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.rulesVersion = f.rulesVersion
|
c.rulesVersion = f.rulesVersion
|
||||||
@@ -660,6 +679,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match
|
||||||
|
if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 {
|
||||||
|
// port numbers are re-used for connection tracking of ICMP,
|
||||||
|
// but we don't want to actually filter on them.
|
||||||
|
return fp[firewall.PortAny].match(p, c, caPool)
|
||||||
|
}
|
||||||
|
|
||||||
var port int32
|
var port int32
|
||||||
|
|
||||||
if p.Fragment {
|
if p.Fragment {
|
||||||
@@ -804,10 +830,8 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, group := range groups {
|
if slices.Contains(groups, "any") {
|
||||||
if group == "any" {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if host == "any" {
|
if host == "any" {
|
||||||
@@ -917,7 +941,7 @@ type rule struct {
|
|||||||
CASha string
|
CASha string
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
|
||||||
r := rule{}
|
r := rule{}
|
||||||
|
|
||||||
m, ok := p.(map[string]any)
|
m, ok := p.(map[string]any)
|
||||||
@@ -948,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
l.Warn("group was an array with a single value, converting to simple value",
|
||||||
|
"table", table,
|
||||||
|
"rule", i,
|
||||||
|
)
|
||||||
m["group"] = v[0]
|
m["group"] = v[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1018,54 +1045,56 @@ func (r *rule) sanity() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.Code != "" {
|
||||||
|
return fmt.Errorf("code specified as [%s]. Support for 'code' will be dropped in a future release, as it has never been functional", r.Code)
|
||||||
|
}
|
||||||
|
|
||||||
//todo alert on cidr-any
|
//todo alert on cidr-any
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePort(s string) (startPort, endPort int32, err error) {
|
func parsePort(s string) (int32, int32, error) {
|
||||||
|
var err error
|
||||||
|
const notAPort int32 = -2
|
||||||
if s == "any" {
|
if s == "any" {
|
||||||
startPort = firewall.PortAny
|
return firewall.PortAny, firewall.PortAny, nil
|
||||||
endPort = firewall.PortAny
|
}
|
||||||
|
if s == "fragment" {
|
||||||
} else if s == "fragment" {
|
return firewall.PortFragment, firewall.PortFragment, nil
|
||||||
startPort = firewall.PortFragment
|
}
|
||||||
endPort = firewall.PortFragment
|
if !strings.Contains(s, `-`) {
|
||||||
|
|
||||||
} else if strings.Contains(s, `-`) {
|
|
||||||
sPorts := strings.SplitN(s, `-`, 2)
|
|
||||||
sPorts[0] = strings.Trim(sPorts[0], " ")
|
|
||||||
sPorts[1] = strings.Trim(sPorts[1], " ")
|
|
||||||
|
|
||||||
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
|
|
||||||
return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
rStartPort, err := strconv.Atoi(sPorts[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
rEndPort, err := strconv.Atoi(sPorts[1])
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
startPort = int32(rStartPort)
|
|
||||||
endPort = int32(rEndPort)
|
|
||||||
|
|
||||||
if startPort == firewall.PortAny {
|
|
||||||
endPort = firewall.PortAny
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
rPort, err := strconv.Atoi(s)
|
rPort, err := strconv.Atoi(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, fmt.Errorf("was not a number; `%s`", s)
|
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
|
||||||
}
|
}
|
||||||
startPort = int32(rPort)
|
return int32(rPort), int32(rPort), nil
|
||||||
endPort = startPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
sPorts := strings.SplitN(s, `-`, 2)
|
||||||
|
for i := range sPorts {
|
||||||
|
sPorts[i] = strings.Trim(sPorts[i], " ")
|
||||||
|
}
|
||||||
|
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
|
||||||
|
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
rStartPort, err := strconv.Atoi(sPorts[0])
|
||||||
|
if err != nil {
|
||||||
|
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
rEndPort, err := strconv.Atoi(sPorts[1])
|
||||||
|
if err != nil {
|
||||||
|
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
startPort := int32(rStartPort)
|
||||||
|
endPort := int32(rEndPort)
|
||||||
|
|
||||||
|
if startPort == firewall.PortAny {
|
||||||
|
endPort = firewall.PortAny
|
||||||
|
}
|
||||||
|
|
||||||
|
return startPort, endPort, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
@@ -15,41 +15,49 @@ type ConntrackCacheTicker struct {
|
|||||||
cacheV uint64
|
cacheV uint64
|
||||||
cacheTick atomic.Uint64
|
cacheTick atomic.Uint64
|
||||||
|
|
||||||
|
l *slog.Logger
|
||||||
cache ConntrackCache
|
cache ConntrackCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
|
||||||
if d == 0 {
|
if d == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ConntrackCacheTicker{
|
c := &ConntrackCacheTicker{
|
||||||
|
l: l,
|
||||||
cache: ConntrackCache{},
|
cache: ConntrackCache{},
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.tick(d)
|
go c.tick(ctx, d)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
|
||||||
|
t := time.NewTicker(d)
|
||||||
|
defer t.Stop()
|
||||||
for {
|
for {
|
||||||
time.Sleep(d)
|
select {
|
||||||
c.cacheTick.Add(1)
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
c.cacheTick.Add(1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
// the map. If it has moved, we reset the map.
|
// the map. If it has moved, we reset the map.
|
||||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||||
c.cacheV = tick
|
c.cacheV = tick
|
||||||
if ll := len(c.cache); ll > 0 {
|
if ll := len(c.cache); ll > 0 {
|
||||||
if l.Level == logrus.DebugLevel {
|
if c.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
c.l.Debug("resetting conntrack cache", "len", ll)
|
||||||
}
|
}
|
||||||
c.cache = make(ConntrackCache, ll)
|
c.cache = make(ConntrackCache, ll)
|
||||||
}
|
}
|
||||||
|
|||||||
69
firewall/cache_test.go
Normal file
69
firewall/cache_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The tests below pin the log format produced by ConntrackCacheTicker.Get
|
||||||
|
// so changes cannot silently break what operators are grepping for. The
|
||||||
|
// ticker's internal state (cache + cacheTick) is poked directly to avoid
|
||||||
|
// racing a goroutine-driven tick in tests.
|
||||||
|
|
||||||
|
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
|
||||||
|
t.Helper()
|
||||||
|
c := &ConntrackCacheTicker{
|
||||||
|
l: l,
|
||||||
|
cache: make(ConntrackCache, cacheLen),
|
||||||
|
}
|
||||||
|
for i := 0; i < cacheLen; i++ {
|
||||||
|
c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{}
|
||||||
|
}
|
||||||
|
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 3)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 2)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 5)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 0)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
}
|
||||||
@@ -23,7 +23,10 @@ const (
|
|||||||
type Packet struct {
|
type Packet struct {
|
||||||
LocalAddr netip.Addr
|
LocalAddr netip.Addr
|
||||||
RemoteAddr netip.Addr
|
RemoteAddr netip.Addr
|
||||||
LocalPort uint16
|
// LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP.
|
||||||
|
LocalPort uint16
|
||||||
|
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
|
||||||
|
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
|
||||||
RemotePort uint16
|
RemotePort uint16
|
||||||
Protocol uint8
|
Protocol uint8
|
||||||
Fragment bool
|
Fragment bool
|
||||||
@@ -47,6 +50,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
|
|||||||
proto = "tcp"
|
proto = "tcp"
|
||||||
case ProtoICMP:
|
case ProtoICMP:
|
||||||
proto = "icmp"
|
proto = "icmp"
|
||||||
|
case ProtoICMPv6:
|
||||||
|
proto = "icmpv6"
|
||||||
case ProtoUDP:
|
case ProtoUDP:
|
||||||
proto = "udp"
|
proto = "udp"
|
||||||
default:
|
default:
|
||||||
|
|||||||
262
firewall_test.go
262
firewall_test.go
@@ -3,13 +3,13 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
c := &dummyCert{}
|
c := &dummyCert{}
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
@@ -87,9 +86,10 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
//no matter what port is given for icmp, it should end up as "any"
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
|
||||||
|
assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
|
||||||
@@ -176,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
@@ -253,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropV6(t *testing.T) {
|
func TestFirewall_DropV6(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
@@ -484,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -543,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -632,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3V6(t *testing.T) {
|
func TestFirewall_Drop3V6(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
@@ -670,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -734,10 +728,152 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
|
network := netip.MustParsePrefix("1.2.3.4/24")
|
||||||
|
|
||||||
|
c := cert.CachedCertificate{
|
||||||
|
Certificate: &dummyCert{
|
||||||
|
name: "host1",
|
||||||
|
networks: []netip.Prefix{network},
|
||||||
|
groups: []string{"default-group"},
|
||||||
|
issuer: "signer-shasum",
|
||||||
|
},
|
||||||
|
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||||
|
}
|
||||||
|
h := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
|
}
|
||||||
|
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||||
|
|
||||||
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
|
templ := firewall.Packet{
|
||||||
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
|
Protocol: firewall.ProtoICMP,
|
||||||
|
Fragment: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ICMP allowed", func(t *testing.T) {
|
||||||
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||||
|
t.Run("zero ports", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0
|
||||||
|
p.RemotePort = 0
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero ports", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0xabcd
|
||||||
|
p.RemotePort = 0x1234
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Any proto, some ports allowed", func(t *testing.T) {
|
||||||
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
|
||||||
|
t.Run("zero ports, still blocked", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0
|
||||||
|
p.RemotePort = 0
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
//now also allow outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero ports, still blocked", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0xabcd
|
||||||
|
p.RemotePort = 0x1234
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
//now also allow outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 80
|
||||||
|
p.RemotePort = 80
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
//now also allow outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("Any proto, any port", func(t *testing.T) {
|
||||||
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||||
|
t.Run("zero ports, allowed", func(t *testing.T) {
|
||||||
|
resetConntrack(fw)
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0
|
||||||
|
p.RemotePort = 0
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
||||||
|
resetConntrack(fw)
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0xabcd
|
||||||
|
p.RemotePort = 0x1234
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
//different ID is blocked
|
||||||
|
p.RemotePort++
|
||||||
|
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||||
|
ob := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||||
|
|
||||||
@@ -897,56 +1033,56 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// Test a bad rule definition
|
// Test a bad rule definition
|
||||||
c := &dummyCert{}
|
c := &dummyCert{}
|
||||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil, "aes")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||||
|
|
||||||
// Test both port and code
|
// Test both port and code
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||||
|
|
||||||
// Test missing host, group, cidr, ca_name and ca_sha
|
// Test missing host, group, cidr, ca_name and ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||||
|
|
||||||
// Test code/port error
|
// Test code/port error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||||
|
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||||
|
|
||||||
// Test proto error
|
// Test proto error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||||
|
|
||||||
// Test cidr parse error
|
// Test cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test local_cidr parse error
|
// Test local_cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test both group and groups
|
// Test both group and groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||||
@@ -955,28 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// Test adding tcp rule
|
// Test adding tcp rule
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
|
// Test adding icmp rule no port
|
||||||
|
conf = config.NewC(test.NewLogger())
|
||||||
|
mf = &mockFirewall{}
|
||||||
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
|
||||||
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
@@ -984,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding rule with cidr
|
// Test adding rule with cidr
|
||||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with local_cidr
|
// Test adding rule with local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
@@ -999,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding rule with cidr ipv6
|
// Test adding rule with cidr ipv6
|
||||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any cidr
|
// Test adding rule with any cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with junk cidr
|
// Test adding rule with junk cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||||
|
|
||||||
// Test adding rule with local_cidr ipv6
|
// Test adding rule with local_cidr ipv6
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any local_cidr
|
// Test adding rule with any local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with junk local_cidr
|
// Test adding rule with junk local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
mf.nextCallReturn = errors.New("test error")
|
mf.nextCallReturn = errors.New("test error")
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
@@ -1082,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRule(t *testing.T) {
|
func TestFirewall_convertRule(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
// Ensure group array of 1 is converted and a warning is printed
|
// Ensure group array of 1 is converted and a warning is printed
|
||||||
c := map[string]any{
|
c := map[string]any{
|
||||||
@@ -1092,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
r, err := convertRule(l, c, "test", 1)
|
||||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
|
||||||
|
assert.Contains(t, ob.String(), "table=test")
|
||||||
|
assert.Contains(t, ob.String(), "rule=1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
assert.Equal(t, []string{"group1"}, r.Groups)
|
||||||
|
|
||||||
@@ -1118,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRuleSanity(t *testing.T) {
|
func TestFirewall_convertRuleSanity(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
noWarningPlease := []map[string]any{
|
noWarningPlease := []map[string]any{
|
||||||
{"group": "group1"},
|
{"group": "group1"},
|
||||||
@@ -1234,7 +1377,7 @@ type testsetup struct {
|
|||||||
fw *Firewall
|
fw *Firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
|
||||||
c := dummyCert{
|
c := dummyCert{
|
||||||
name: "me",
|
name: "me",
|
||||||
networks: myPrefixes,
|
networks: myPrefixes,
|
||||||
@@ -1245,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
|
|||||||
return newSetupFromCert(t, l, c)
|
return newSetupFromCert(t, l, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
for _, prefix := range c.Networks() {
|
for _, prefix := range c.Networks() {
|
||||||
myVpnNetworksTable.Insert(prefix)
|
myVpnNetworksTable.Insert(prefix)
|
||||||
@@ -1262,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|||||||
|
|
||||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
||||||
|
|||||||
25
go.mod
25
go.mod
@@ -1,9 +1,10 @@
|
|||||||
module github.com/slackhq/nebula
|
module github.com/slackhq/nebula
|
||||||
|
|
||||||
go 1.25
|
go 1.25.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
dario.cat/mergo v1.0.2
|
dario.cat/mergo v1.0.2
|
||||||
|
filippo.io/bigmod v0.1.0
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
||||||
github.com/armon/go-radix v1.0.0
|
github.com/armon/go-radix v1.0.0
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||||
@@ -12,26 +13,26 @@ require (
|
|||||||
github.com/gogo/protobuf v1.3.2
|
github.com/gogo/protobuf v1.3.2
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/kardianos/service v1.2.4
|
github.com/kardianos/service v1.2.4
|
||||||
github.com/miekg/dns v1.1.70
|
github.com/miekg/dns v1.1.72
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
github.com/miekg/pkcs11 v1.1.2
|
||||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||||
github.com/sirupsen/logrus v1.9.4
|
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
|
go.uber.org/goleak v1.3.0
|
||||||
go.yaml.in/yaml/v3 v3.0.4
|
go.yaml.in/yaml/v3 v3.0.4
|
||||||
golang.org/x/crypto v0.47.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.49.0
|
golang.org/x/net v0.52.0
|
||||||
golang.org/x/sync v0.19.0
|
golang.org/x/sync v0.20.0
|
||||||
golang.org/x/sys v0.40.0
|
golang.org/x/sys v0.43.0
|
||||||
golang.org/x/term v0.39.0
|
golang.org/x/term v0.42.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.6.1
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||||
@@ -49,7 +50,7 @@ require (
|
|||||||
github.com/prometheus/procfs v0.16.1 // indirect
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
golang.org/x/mod v0.31.0 // indirect
|
golang.org/x/mod v0.34.0 // indirect
|
||||||
golang.org/x/time v0.5.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.40.0 // indirect
|
golang.org/x/tools v0.43.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
44
go.sum
44
go.sum
@@ -1,6 +1,8 @@
|
|||||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||||
|
filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8=
|
||||||
|
filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||||
@@ -83,10 +85,10 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
|
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
|
||||||
github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
|
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
@@ -131,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
|
|||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
|
||||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
||||||
@@ -162,16 +162,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -191,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
@@ -223,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
|||||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
@@ -233,8 +233,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
|
|||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.6.1 h1:XMaKojH1Hs/raMrmnir4n35nTvzvWj7NmSYzHn2F4qU=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.6.1/go.mod h1:04aqInu5GYuTFvMuDw/rKBAF7mHrltW/3rekpfbbZDM=
|
||||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||||
|
|||||||
57
handshake/credential.go
Normal file
57
handshake/credential.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
|
||||||
|
"github.com/flynn/noise"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Credential holds everything needed to participate in a handshake
|
||||||
|
// at a given cert version. Version and Curve are read from Cert; the public
|
||||||
|
// half of the static keypair likewise comes from Cert.PublicKey().
|
||||||
|
type Credential struct {
|
||||||
|
Cert cert.Certificate // the certificate
|
||||||
|
Bytes []byte // pre-marshaled certificate bytes
|
||||||
|
privateKey []byte // static private key (public half lives in Cert)
|
||||||
|
cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCredential creates a Credential with all material needed for handshake
|
||||||
|
// participation. The cipherSuite should be pre-built by the caller with the
|
||||||
|
// appropriate DH function, cipher, and hash.
|
||||||
|
func NewCredential(
|
||||||
|
c cert.Certificate,
|
||||||
|
hsBytes []byte,
|
||||||
|
privateKey []byte,
|
||||||
|
cipherSuite noise.CipherSuite,
|
||||||
|
) *Credential {
|
||||||
|
return &Credential{
|
||||||
|
Cert: c,
|
||||||
|
Bytes: hsBytes,
|
||||||
|
privateKey: privateKey,
|
||||||
|
cipherSuite: cipherSuite,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildHandshakeState creates a noise.HandshakeState from this credential.
|
||||||
|
func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) {
|
||||||
|
return noise.NewHandshakeState(noise.Config{
|
||||||
|
CipherSuite: hc.cipherSuite,
|
||||||
|
Random: rand.Reader,
|
||||||
|
Pattern: pattern,
|
||||||
|
Initiator: initiator,
|
||||||
|
StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()},
|
||||||
|
PresharedKey: []byte{},
|
||||||
|
PresharedKeyPlacement: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCredentialFunc returns the handshake credential for the given version,
|
||||||
|
// or nil if that version is not available.
|
||||||
|
//
|
||||||
|
// Implementations must return credentials drawn from a snapshot stable for
|
||||||
|
// the lifetime of any single Machine. The Machine may call this multiple
|
||||||
|
// times during a handshake (e.g. when negotiating to the peer's version)
|
||||||
|
// and assumes the underlying static keypair is consistent across calls.
|
||||||
|
type GetCredentialFunc func(v cert.Version) *Credential
|
||||||
21
handshake/errors.go
Normal file
21
handshake/errors.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInitiateOnResponder = errors.New("initiate called on responder")
|
||||||
|
ErrInitiateAlreadyCalled = errors.New("initiate already called")
|
||||||
|
ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators")
|
||||||
|
ErrPacketTooShort = errors.New("packet too short")
|
||||||
|
ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake")
|
||||||
|
ErrIncompleteHandshake = errors.New("handshake completed without receiving required content")
|
||||||
|
ErrMachineFailed = errors.New("handshake machine has failed")
|
||||||
|
ErrUnknownSubtype = errors.New("unknown handshake subtype")
|
||||||
|
ErrMissingContent = errors.New("expected handshake content but message was empty")
|
||||||
|
ErrUnexpectedContent = errors.New("received unexpected handshake content")
|
||||||
|
ErrIndexAllocation = errors.New("failed to allocate local index")
|
||||||
|
ErrNoCredential = errors.New("no handshake credential available for cert version")
|
||||||
|
ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key")
|
||||||
|
ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager")
|
||||||
|
ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype")
|
||||||
|
)
|
||||||
37
handshake/handshake.proto
Normal file
37
handshake/handshake.proto
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// This file documents the wire format the nebula handshake speaks. It is
|
||||||
|
// not run through protoc; the encoder/decoder in payload.go is hand-written
|
||||||
|
// against this shape directly to keep the parser narrow and panic-free.
|
||||||
|
//
|
||||||
|
// Any change to the wire format must be reflected here, and adding a new
|
||||||
|
// field requires updating MarshalPayload / unmarshalPayloadDetails together
|
||||||
|
// with the field-uniqueness and wire-type checks in those functions.
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
package nebula.handshake;
|
||||||
|
|
||||||
|
message NebulaHandshake {
|
||||||
|
NebulaHandshakeDetails Details = 1;
|
||||||
|
bytes Hmac = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message NebulaHandshakeDetails {
|
||||||
|
bytes Cert = 1;
|
||||||
|
uint32 InitiatorIndex = 2;
|
||||||
|
uint32 ResponderIndex = 3;
|
||||||
|
// Cookie was reserved for an anti-DoS mechanism that was never
|
||||||
|
// implemented. No released version of nebula has ever populated it; the
|
||||||
|
// hand-written parser silently skips it on read.
|
||||||
|
uint64 Cookie = 4 [deprecated = true];
|
||||||
|
uint64 Time = 5;
|
||||||
|
uint32 CertVersion = 8;
|
||||||
|
|
||||||
|
MultiPortDetails InitiatorMultiPort = 6;
|
||||||
|
MultiPortDetails ResponderMultiPort = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MultiPortDetails {
|
||||||
|
bool RxSupported = 1;
|
||||||
|
bool TxSupported = 2;
|
||||||
|
uint32 BasePort = 3;
|
||||||
|
uint32 TotalPorts = 4;
|
||||||
|
}
|
||||||
116
handshake/helpers_test.go
Normal file
116
handshake/helpers_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/flynn/noise"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
ct "github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testCertState holds cert material for a test peer.
|
||||||
|
type testCertState struct {
|
||||||
|
version cert.Version
|
||||||
|
creds map[cert.Version]*Credential
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testCertState) getCredential(v cert.Version) *Credential {
|
||||||
|
return s.creds[v]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestCertState(
|
||||||
|
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
|
||||||
|
) *testCertState {
|
||||||
|
return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestCertStateWithCipher(
|
||||||
|
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
|
||||||
|
cipher noise.CipherFunc,
|
||||||
|
) *testCertState {
|
||||||
|
t.Helper()
|
||||||
|
c, _, rawPrivKey, _ := ct.NewTestCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||||
|
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
hsBytes, err := c.MarshalForHandshakes()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256)
|
||||||
|
return &testCertState{
|
||||||
|
version: cert.Version2,
|
||||||
|
creds: map[cert.Version]*Credential{
|
||||||
|
cert.Version2: NewCredential(c, hsBytes, priv, ncs),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testVerifier(pool *cert.CAPool) CertVerifier {
|
||||||
|
return func(c cert.Certificate) (*cert.CachedCertificate, error) {
|
||||||
|
return pool.VerifyCertificate(time.Now(), c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestMachine(
|
||||||
|
t *testing.T,
|
||||||
|
cs *testCertState,
|
||||||
|
verifier CertVerifier,
|
||||||
|
initiator bool,
|
||||||
|
localIndex uint32,
|
||||||
|
) *Machine {
|
||||||
|
t.Helper()
|
||||||
|
m, err := NewMachine(
|
||||||
|
cs.version, cs.getCredential,
|
||||||
|
verifier, func() (uint32, error) { return localIndex, nil },
|
||||||
|
initiator, header.HandshakeIXPSK0,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func initiateHandshake(
|
||||||
|
t *testing.T,
|
||||||
|
initCS *testCertState, initVerifier CertVerifier,
|
||||||
|
respCS *testCertState, respVerifier CertVerifier,
|
||||||
|
) (initM, respM *Machine, respResult *Result, resp []byte, err error) {
|
||||||
|
t.Helper()
|
||||||
|
initM = newTestMachine(t, initCS, initVerifier, true, 100)
|
||||||
|
msg1, merr := initM.Initiate(nil)
|
||||||
|
require.NoError(t, merr)
|
||||||
|
|
||||||
|
respM = newTestMachine(t, respCS, respVerifier, false, 200)
|
||||||
|
resp, respResult, err = respM.ProcessPacket(nil, msg1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func doFullHandshake(
|
||||||
|
t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool,
|
||||||
|
) (initResult, respResult *Result) {
|
||||||
|
t.Helper()
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
initM := newTestMachine(t, initCS, v, true, 1000)
|
||||||
|
respM := newTestMachine(t, respCS, v, false, 2000)
|
||||||
|
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, respResult, err := respM.ProcessPacket(nil, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, respResult)
|
||||||
|
require.NotEmpty(t, resp)
|
||||||
|
|
||||||
|
_, initResult, err = initM.ProcessPacket(nil, resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, initResult)
|
||||||
|
|
||||||
|
return initResult, respResult
|
||||||
|
}
|
||||||
444
handshake/machine.go
Normal file
444
handshake/machine.go
Normal file
@@ -0,0 +1,444 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/flynn/noise"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IndexAllocator is called by the Machine to allocate a local index for the
|
||||||
|
// handshake. It is called at most once, when the first outgoing message that
|
||||||
|
// carries a payload is built.
|
||||||
|
//
|
||||||
|
// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning
|
||||||
|
// "no index assigned" on the wire and in the payload-presence checks. If an
|
||||||
|
// allocator ever returned 0, a legitimate handshake's payload could be
|
||||||
|
// indistinguishable from an empty one and would be rejected.
|
||||||
|
type IndexAllocator func() (uint32, error)
|
||||||
|
|
||||||
|
// CertVerifier is called by the Machine after reconstructing the peer's
|
||||||
|
// certificate from the handshake. The verifier performs all validation
|
||||||
|
// (CA trust, expiry, policy checks, allow lists).
|
||||||
|
type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
|
||||||
|
|
||||||
|
// Result contains the results of a successful handshake.
|
||||||
|
// Returned by ProcessPacket when the handshake is complete.
|
||||||
|
type Result struct {
|
||||||
|
EKey *noise.CipherState
|
||||||
|
DKey *noise.CipherState
|
||||||
|
MyCert cert.Certificate
|
||||||
|
RemoteCert *cert.CachedCertificate
|
||||||
|
RemoteIndex uint32
|
||||||
|
LocalIndex uint32
|
||||||
|
HandshakeTime uint64
|
||||||
|
MessageIndex uint64 // number of messages exchanged during the handshake
|
||||||
|
Initiator bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Machine drives a Noise handshake through N messages. It handles Noise
|
||||||
|
// protocol operations, certificate reconstruction, and payload encoding.
|
||||||
|
// Certificate validation is delegated to the caller via CertVerifier.
|
||||||
|
//
|
||||||
|
// A Machine is not safe for concurrent use. The caller must ensure that
|
||||||
|
// Initiate and ProcessPacket are not called concurrently.
|
||||||
|
//
|
||||||
|
// Error contract: when ProcessPacket or Initiate returns an error, callers
|
||||||
|
// must check Failed() to decide what to do next. If Failed() is false the
|
||||||
|
// underlying noise state was not advanced (the packet was rejected before
|
||||||
|
// ReadMessage took effect, or the rejection is non-fatal like a stale
|
||||||
|
// retransmit) and the Machine can accept another packet. If Failed() is
|
||||||
|
// true the Machine is unrecoverable and the caller must abandon it.
|
||||||
|
type Machine struct {
|
||||||
|
hs *noise.HandshakeState
|
||||||
|
getCred GetCredentialFunc
|
||||||
|
allocIndex IndexAllocator
|
||||||
|
verifier CertVerifier
|
||||||
|
result *Result
|
||||||
|
msgs []msgFlags
|
||||||
|
myVersion cert.Version
|
||||||
|
subtype header.MessageSubType
|
||||||
|
indexAllocated bool
|
||||||
|
remoteCertSet bool
|
||||||
|
payloadSet bool
|
||||||
|
failed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMachine creates a handshake state machine. The subtype determines both
|
||||||
|
// the noise pattern and the per-message content layout. The credential for
|
||||||
|
// `version` is fetched via getCred and used to seed the noise.HandshakeState.
|
||||||
|
// IndexAllocator is called lazily when the first outgoing payload is built.
|
||||||
|
func NewMachine(
|
||||||
|
version cert.Version,
|
||||||
|
getCred GetCredentialFunc,
|
||||||
|
verifier CertVerifier,
|
||||||
|
allocIndex IndexAllocator,
|
||||||
|
initiator bool,
|
||||||
|
subtype header.MessageSubType,
|
||||||
|
) (*Machine, error) {
|
||||||
|
info, err := subtypeInfoFor(subtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := getCred(version)
|
||||||
|
if cred == nil {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrNoCredential, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
hs, err := cred.buildHandshakeState(initiator, info.pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build noise state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Machine{
|
||||||
|
hs: hs,
|
||||||
|
subtype: subtype,
|
||||||
|
msgs: info.msgs,
|
||||||
|
getCred: getCred,
|
||||||
|
allocIndex: allocIndex,
|
||||||
|
verifier: verifier,
|
||||||
|
myVersion: version,
|
||||||
|
result: &Result{
|
||||||
|
Initiator: initiator,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Failed returns true if the Machine is in an unrecoverable state.
|
||||||
|
func (m *Machine) Failed() bool {
|
||||||
|
return m.failed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subtype returns the handshake subtype this Machine was built for.
|
||||||
|
func (m *Machine) Subtype() header.MessageSubType {
|
||||||
|
return m.subtype
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageIndex returns the noise handshake message index, which equals the
|
||||||
|
// wire counter of the most recently sent or received message.
|
||||||
|
func (m *Machine) MessageIndex() int {
|
||||||
|
return m.hs.MessageIndex()
|
||||||
|
}
|
||||||
|
|
||||||
|
// requireComplete checks that both a peer cert and payload have been received.
|
||||||
|
// Marks the machine as failed if not.
|
||||||
|
func (m *Machine) requireComplete() error {
|
||||||
|
if !m.payloadSet || !m.remoteCertSet {
|
||||||
|
m.failed = true
|
||||||
|
return ErrIncompleteHandshake
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// myMsgFlags returns the flags for the current outgoing message.
|
||||||
|
func (m *Machine) myMsgFlags() msgFlags {
|
||||||
|
idx := m.hs.MessageIndex()
|
||||||
|
if idx < len(m.msgs) {
|
||||||
|
return m.msgs[idx]
|
||||||
|
}
|
||||||
|
return msgFlags{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerMsgFlags returns the flags for the message we just read.
|
||||||
|
func (m *Machine) peerMsgFlags() msgFlags {
|
||||||
|
idx := m.hs.MessageIndex() - 1
|
||||||
|
if idx >= 0 && idx < len(m.msgs) {
|
||||||
|
return m.msgs[idx]
|
||||||
|
}
|
||||||
|
return msgFlags{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initiate produces the first handshake message. Only valid for initiators,
|
||||||
|
// and must be called exactly once before ProcessPacket.
|
||||||
|
//
|
||||||
|
// out is a destination buffer the message is appended to and returned. Pass
|
||||||
|
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
||||||
|
// buf[:0]) with sufficient capacity to avoid allocation.
|
||||||
|
//
|
||||||
|
// An error return may not indicate a fatal condition, check Failed() to
|
||||||
|
// determine if the Machine can still be used.
|
||||||
|
func (m *Machine) Initiate(out []byte) ([]byte, error) {
|
||||||
|
if m.failed {
|
||||||
|
return nil, ErrMachineFailed
|
||||||
|
}
|
||||||
|
if !m.result.Initiator {
|
||||||
|
m.failed = true
|
||||||
|
return nil, ErrInitiateOnResponder
|
||||||
|
}
|
||||||
|
if m.hs.MessageIndex() != 0 {
|
||||||
|
m.failed = true
|
||||||
|
return nil, ErrInitiateAlreadyCalled
|
||||||
|
}
|
||||||
|
|
||||||
|
// At MessageIndex=0 with RemoteIndex still zero, buildResponse produces
|
||||||
|
// header counter 1 and remote index 0, which is what the initial message needs.
|
||||||
|
out, _, _, err := m.buildResponse(out)
|
||||||
|
if err != nil {
|
||||||
|
m.failed = true
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessPacket handles an incoming handshake message. It advances the Noise
|
||||||
|
// state, validates the peer certificate via the verifier, and optionally
|
||||||
|
// produces a response.
|
||||||
|
//
|
||||||
|
// out is a destination buffer the response is appended to and returned. Pass
|
||||||
|
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
||||||
|
// buf[:0]) with sufficient capacity to avoid allocation. The returned slice
|
||||||
|
// is nil when no outgoing message is produced (handshake complete on this
|
||||||
|
// side, or final message of a multi-message pattern).
|
||||||
|
//
|
||||||
|
// Returns a non-nil Result when the handshake is complete.
|
||||||
|
// An error return may not indicate a fatal condition, check Failed() to
|
||||||
|
// determine if the Machine can still be used.
|
||||||
|
func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) {
|
||||||
|
if m.failed {
|
||||||
|
return nil, nil, ErrMachineFailed
|
||||||
|
}
|
||||||
|
if len(packet) < header.Len {
|
||||||
|
return nil, nil, ErrPacketTooShort
|
||||||
|
}
|
||||||
|
// Reject packets whose subtype doesn't match the one this Machine was
|
||||||
|
// built for. A pending handshake that suddenly receives a different
|
||||||
|
// subtype on its index is either a stray packet that matched by chance
|
||||||
|
// or a peer protocol violation; drop it without failing the Machine so
|
||||||
|
// the legitimate retransmit can still complete.
|
||||||
|
if header.MessageSubType(packet[1]) != m.subtype {
|
||||||
|
return nil, nil, ErrSubtypeMismatch
|
||||||
|
}
|
||||||
|
if m.result.Initiator && m.hs.MessageIndex() == 0 {
|
||||||
|
m.failed = true
|
||||||
|
return nil, nil, ErrInitiateNotCalled
|
||||||
|
}
|
||||||
|
|
||||||
|
// The (eKey, dKey) ordering here is correct for IX, where the initiator
|
||||||
|
// completes the handshake by reading the responder's stage-2 message.
|
||||||
|
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
|
||||||
|
// For 3-message patterns where a responder finishes by reading the final
|
||||||
|
// message, this ordering would be wrong; revisit when XX/pqIX lands.
|
||||||
|
msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:])
|
||||||
|
if err != nil {
|
||||||
|
// Noise ReadMessage failed. The noise library checkpoints and rolls back
|
||||||
|
// on failure, so the Machine is still alive. The caller can retry with
|
||||||
|
// a different packet.
|
||||||
|
return nil, nil, fmt.Errorf("noise ReadMessage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// From here on, noise state has advanced. Any error is fatal.
|
||||||
|
flags := m.peerMsgFlags()
|
||||||
|
|
||||||
|
if err := m.processPayload(msg, flags); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If ReadMessage derived keys, the handshake is complete. Noise should
|
||||||
|
// always produce both keys together; asymmetry is a protocol invariant
|
||||||
|
// violation.
|
||||||
|
if eKey != nil || dKey != nil {
|
||||||
|
if eKey == nil || dKey == nil {
|
||||||
|
m.failed = true
|
||||||
|
return nil, nil, ErrAsymmetricCipherKeys
|
||||||
|
}
|
||||||
|
if err := m.requireComplete(); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return nil, m.completed(eKey, dKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadMessage didn't complete, produce the next outgoing message
|
||||||
|
out, dk, ek, err := m.buildResponse(out)
|
||||||
|
if err != nil {
|
||||||
|
m.failed = true
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ek != nil || dk != nil {
|
||||||
|
if ek == nil || dk == nil {
|
||||||
|
m.failed = true
|
||||||
|
return nil, nil, ErrAsymmetricCipherKeys
|
||||||
|
}
|
||||||
|
if err := m.requireComplete(); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return out, m.completed(ek, dk), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result {
|
||||||
|
m.result.EKey = eKey
|
||||||
|
m.result.DKey = dKey
|
||||||
|
m.result.MessageIndex = uint64(m.hs.MessageIndex())
|
||||||
|
return m.result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
|
||||||
|
if len(msg) == 0 {
|
||||||
|
if flags.expectsPayload || flags.expectsCert {
|
||||||
|
m.failed = true
|
||||||
|
return ErrMissingContent
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := UnmarshalPayload(msg)
|
||||||
|
if err != nil {
|
||||||
|
m.failed = true
|
||||||
|
return fmt.Errorf("unmarshal handshake: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert the payload contains exactly what we expect
|
||||||
|
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
|
||||||
|
if hasPayloadData != flags.expectsPayload {
|
||||||
|
m.failed = true
|
||||||
|
return ErrUnexpectedContent
|
||||||
|
}
|
||||||
|
|
||||||
|
hasCertData := len(payload.Cert) > 0
|
||||||
|
if hasCertData != flags.expectsCert {
|
||||||
|
m.failed = true
|
||||||
|
return ErrUnexpectedContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process payload
|
||||||
|
if flags.expectsPayload {
|
||||||
|
if m.result.Initiator {
|
||||||
|
m.result.RemoteIndex = payload.ResponderIndex
|
||||||
|
} else {
|
||||||
|
m.result.RemoteIndex = payload.InitiatorIndex
|
||||||
|
}
|
||||||
|
m.result.HandshakeTime = payload.Time
|
||||||
|
m.payloadSet = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process certificate
|
||||||
|
if flags.expectsCert {
|
||||||
|
if err := m.validateCert(payload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Machine) validateCert(payload Payload) error {
|
||||||
|
cred := m.getCred(m.myVersion)
|
||||||
|
if cred == nil {
|
||||||
|
m.failed = true
|
||||||
|
return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
||||||
|
}
|
||||||
|
rc, err := cert.Recombine(
|
||||||
|
cert.Version(payload.CertVersion),
|
||||||
|
payload.Cert,
|
||||||
|
m.hs.PeerStatic(),
|
||||||
|
cred.Cert.Curve(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
m.failed = true
|
||||||
|
return fmt.Errorf("recombine cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) {
|
||||||
|
m.failed = true
|
||||||
|
return ErrPublicKeyMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version negotiation, if the peer sent a different version and we have it, switch
|
||||||
|
if rc.Version() != m.myVersion {
|
||||||
|
if m.getCred(rc.Version()) != nil {
|
||||||
|
m.myVersion = rc.Version()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verified, err := m.verifier(rc)
|
||||||
|
if err != nil {
|
||||||
|
m.failed = true
|
||||||
|
return fmt.Errorf("verify cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.result.RemoteCert = verified
|
||||||
|
m.remoteCertSet = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
|
||||||
|
if !flags.expectsPayload && !flags.expectsCert {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var p Payload
|
||||||
|
if flags.expectsPayload {
|
||||||
|
if !m.indexAllocated {
|
||||||
|
index, err := m.allocIndex()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err)
|
||||||
|
}
|
||||||
|
m.result.LocalIndex = index
|
||||||
|
m.indexAllocated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.result.Initiator {
|
||||||
|
p.InitiatorIndex = m.result.LocalIndex
|
||||||
|
} else {
|
||||||
|
p.ResponderIndex = m.result.LocalIndex
|
||||||
|
p.InitiatorIndex = m.result.RemoteIndex
|
||||||
|
}
|
||||||
|
p.Time = uint64(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
if flags.expectsCert {
|
||||||
|
cred := m.getCred(m.myVersion)
|
||||||
|
if cred == nil {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
||||||
|
}
|
||||||
|
p.Cert = cred.Bytes
|
||||||
|
p.CertVersion = uint32(cred.Cert.Version())
|
||||||
|
m.result.MyCert = cred.Cert
|
||||||
|
}
|
||||||
|
|
||||||
|
return MarshalPayload(nil, p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
||||||
|
flags := m.myMsgFlags()
|
||||||
|
hsBytes, err := m.marshalOutgoing(flags)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend out by header.Len to make room for the header. slices.Grow is a
|
||||||
|
// no-op when the cap is already sufficient (the zero-copy case where the
|
||||||
|
// caller passed a pre-sized buffer). header.Encode overwrites the new
|
||||||
|
// bytes, so they don't need to be zeroed.
|
||||||
|
start := len(out)
|
||||||
|
out = slices.Grow(out, header.Len)[:start+header.Len]
|
||||||
|
header.Encode(
|
||||||
|
out[start:],
|
||||||
|
header.Version, header.Handshake, m.subtype,
|
||||||
|
m.result.RemoteIndex,
|
||||||
|
uint64(m.hs.MessageIndex()+1),
|
||||||
|
)
|
||||||
|
|
||||||
|
// noise.WriteMessage appends the encrypted handshake message to out,
|
||||||
|
// reusing capacity when present.
|
||||||
|
//
|
||||||
|
// The (dKey, eKey) ordering here is correct for IX, where the responder
|
||||||
|
// completes the handshake by writing the stage-2 message. noise returns
|
||||||
|
// (cs1, cs2) where cs1 is the initiator->responder cipher (which is the
|
||||||
|
// responder's decrypt key). For 3-message patterns where an initiator
|
||||||
|
// finishes by writing the final message, this ordering would be wrong;
|
||||||
|
// revisit when XX/pqIX lands.
|
||||||
|
out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, dKey, eKey, nil
|
||||||
|
}
|
||||||
662
handshake/machine_test.go
Normal file
662
handshake/machine_test.go
Normal file
@@ -0,0 +1,662 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/flynn/noise"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
ct "github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMachineIXHappyPath(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
|
||||||
|
initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||||
|
|
||||||
|
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||||
|
|
||||||
|
assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name())
|
||||||
|
assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name())
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(1000), initR.LocalIndex)
|
||||||
|
assert.Equal(t, uint32(2000), initR.RemoteIndex)
|
||||||
|
assert.Equal(t, uint32(2000), respR.LocalIndex)
|
||||||
|
assert.Equal(t, uint32(1000), respR.RemoteIndex)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages")
|
||||||
|
assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages")
|
||||||
|
|
||||||
|
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("hello"), pt1)
|
||||||
|
|
||||||
|
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("world"), pt2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineInitiateErrors(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
t.Run("initiate on responder", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
_, err := m.Initiate(nil)
|
||||||
|
require.ErrorIs(t, err, ErrInitiateOnResponder)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("initiate called twice", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, true, 100)
|
||||||
|
_, err := m.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = m.Initiate(nil)
|
||||||
|
require.ErrorIs(t, err, ErrInitiateAlreadyCalled)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("process packet before initiate on initiator", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, true, 100)
|
||||||
|
_, _, err := m.ProcessPacket(nil, make([]byte, 100))
|
||||||
|
require.ErrorIs(t, err, ErrInitiateNotCalled)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("calling failed machine", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
_, err := m.Initiate(nil) // fails: responder
|
||||||
|
require.Error(t, err)
|
||||||
|
_, err = m.Initiate(nil) // fails: already failed
|
||||||
|
require.ErrorIs(t, err, ErrMachineFailed)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineProcessPacketErrors(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
t.Run("packet too short", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
_, _, err := m.ProcessPacket(nil, []byte{1, 2, 3})
|
||||||
|
require.ErrorIs(t, err, ErrPacketTooShort)
|
||||||
|
assert.False(t, m.Failed(), "short packet should not kill machine")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("noise decryption failure is recoverable", func(t *testing.T) {
|
||||||
|
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
initM := newTestMachine(t, initCS, v, true, 100)
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
respM := newTestMachine(t, cs, v, false, 200)
|
||||||
|
resp, _, err := respM.ProcessPacket(nil, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
corrupted := make([]byte, len(resp))
|
||||||
|
copy(corrupted, resp)
|
||||||
|
for i := header.Len; i < len(corrupted); i++ {
|
||||||
|
corrupted[i] ^= 0xff
|
||||||
|
}
|
||||||
|
_, _, err = initM.ProcessPacket(nil, corrupted)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.False(t, initM.Failed(), "noise failure should be recoverable")
|
||||||
|
|
||||||
|
// And the machine should still complete a real handshake afterward.
|
||||||
|
_, result, err := initM.ProcessPacket(nil, resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result, "initiator should complete on the legitimate response")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid cert is fatal", func(t *testing.T) {
|
||||||
|
otherCA, _, otherCAKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||||
|
|
||||||
|
initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100)
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
respM := newTestMachine(t, cs, v, false, 200)
|
||||||
|
_, _, err = respM.ProcessPacket(nil, msg1)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, respM.Failed(), "cert validation failure should kill machine")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subtype mismatch is recoverable", func(t *testing.T) {
|
||||||
|
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
initM := newTestMachine(t, initCS, v, true, 100)
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Mutate the subtype byte (offset 1 in the header) to a value the
|
||||||
|
// responder Machine wasn't built for.
|
||||||
|
bad := make([]byte, len(msg1))
|
||||||
|
copy(bad, msg1)
|
||||||
|
bad[1] = 0xff
|
||||||
|
|
||||||
|
respM := newTestMachine(t, cs, v, false, 200)
|
||||||
|
_, _, err = respM.ProcessPacket(nil, bad)
|
||||||
|
require.ErrorIs(t, err, ErrSubtypeMismatch)
|
||||||
|
assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine")
|
||||||
|
|
||||||
|
// And the machine should still complete a real handshake afterward.
|
||||||
|
resp, result, err := respM.ProcessPacket(nil, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet")
|
||||||
|
assert.NotEmpty(t, resp, "responder should produce a stage-2 reply")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMachineProcessPayload exercises processPayload's internal validation
|
||||||
|
// directly. Most of these failure modes can't be reached black-box once the
|
||||||
|
// subtype check at the top of ProcessPacket gates external callers, so we
|
||||||
|
// drive them by hand here for coverage.
|
||||||
|
func TestMachineProcessPayload(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
t.Run("empty message with expects fails", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true})
|
||||||
|
require.ErrorIs(t, err, ErrMissingContent)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty message with no expects passes", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
err := m.processPayload(nil, msgFlags{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("malformed protobuf is fatal", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexpected payload data is fatal", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
// A payload with index data when none was expected.
|
||||||
|
bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})
|
||||||
|
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
|
||||||
|
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexpected cert data is fatal", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
// A payload with cert when none was expected.
|
||||||
|
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
|
||||||
|
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
|
||||||
|
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing payload data when expected is fatal", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
// Cert present, but no index/time fields.
|
||||||
|
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
|
||||||
|
err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true})
|
||||||
|
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMachineRequireComplete checks the fail-on-incomplete-handshake path
|
||||||
|
// directly. Like processPayload above this isn't reachable from a normal IX
|
||||||
|
// flow, so we drive it by hand.
|
||||||
|
func TestMachineRequireComplete(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
t.Run("missing both fails", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
err := m.requireComplete()
|
||||||
|
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("payload only fails", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
m.payloadSet = true
|
||||||
|
err := m.requireComplete()
|
||||||
|
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cert only fails", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
m.remoteCertSet = true
|
||||||
|
err := m.requireComplete()
|
||||||
|
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||||
|
assert.True(t, m.Failed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("both set passes", func(t *testing.T) {
|
||||||
|
m := newTestMachine(t, cs, v, false, 100)
|
||||||
|
m.payloadSet = true
|
||||||
|
m.remoteCertSet = true
|
||||||
|
err := m.requireComplete()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, m.Failed())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineAESCipher(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
|
||||||
|
initCS := newTestCertStateWithCipher(
|
||||||
|
t, ca, caKey, "init",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||||
|
noiseutil.CipherAESGCM,
|
||||||
|
)
|
||||||
|
respCS := newTestCertStateWithCipher(
|
||||||
|
t, ca, caKey, "resp",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||||
|
noiseutil.CipherAESGCM,
|
||||||
|
)
|
||||||
|
|
||||||
|
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||||
|
|
||||||
|
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("works"), pt1)
|
||||||
|
|
||||||
|
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("back"), pt2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResultFields(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||||
|
|
||||||
|
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||||
|
|
||||||
|
assert.True(t, initR.Initiator)
|
||||||
|
assert.False(t, respR.Initiator)
|
||||||
|
assert.NotZero(t, initR.HandshakeTime)
|
||||||
|
assert.NotZero(t, respR.HandshakeTime)
|
||||||
|
assert.NotNil(t, initR.RemoteCert)
|
||||||
|
assert.NotNil(t, respR.RemoteCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineBufferReuse(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
initM := newTestMachine(t, initCS, v, true, 1000)
|
||||||
|
respM := newTestMachine(t, respCS, v, false, 2000)
|
||||||
|
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("response writes into provided buffer", func(t *testing.T) {
|
||||||
|
buf := make([]byte, 0, 4096)
|
||||||
|
resp, result, err := respM.ProcessPacket(buf, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.NotEmpty(t, resp, "response should have content")
|
||||||
|
assert.Equal(t, &buf[:1][0], &resp[:1][0],
|
||||||
|
"response should reuse the provided buffer's backing array")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("initiate writes into provided buffer", func(t *testing.T) {
|
||||||
|
initM2 := newTestMachine(t, initCS, v, true, 3000)
|
||||||
|
buf := make([]byte, 0, 4096)
|
||||||
|
msg, err := initM2.Initiate(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEmpty(t, msg, "initiate should have content")
|
||||||
|
assert.Equal(t, &buf[:1][0], &msg[:1][0],
|
||||||
|
"initiate should reuse the provided buffer's backing array")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil out still works", func(t *testing.T) {
|
||||||
|
initM2 := newTestMachine(t, initCS, v, true, 4000)
|
||||||
|
respM2 := newTestMachine(t, respCS, v, false, 5000)
|
||||||
|
|
||||||
|
msg1, err := initM2.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, _, err := respM2.ProcessPacket(nil, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, result, err := initM2.ProcessPacket(nil, resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Nil(t, out, "initiator should have no response for IX msg2")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineMsgIndexTracking(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
initM := newTestMachine(t, initCS, v, true, 100)
|
||||||
|
respM := newTestMachine(t, respCS, v, false, 200)
|
||||||
|
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp1, result1, err := respM.ProcessPacket(nil, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result1)
|
||||||
|
|
||||||
|
_, result2, err := initM.ProcessPacket(nil, resp1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineThreeMessagePattern(t *testing.T) {
|
||||||
|
registerTestXXInfo(t)
|
||||||
|
|
||||||
|
// Use HandshakeXX (3 messages) to verify the Machine handles multi-message
|
||||||
|
// patterns correctly. XX flow:
|
||||||
|
// msg1 (I->R): [E] - payload only, no cert
|
||||||
|
// msg2 (R->I): [E, ee, S, es] - payload + cert
|
||||||
|
// msg3 (I->R): [S, se] - cert only (no payload, not first two)
|
||||||
|
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||||
|
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||||
|
|
||||||
|
initM, err := NewMachine(
|
||||||
|
cert.Version2,
|
||||||
|
initCS.getCredential, v,
|
||||||
|
func() (uint32, error) { return 1000, nil },
|
||||||
|
true, header.HandshakeXXPSK0,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
respM, err := NewMachine(
|
||||||
|
cert.Version2,
|
||||||
|
respCS.getCredential, v,
|
||||||
|
func() (uint32, error) { return 2000, nil },
|
||||||
|
false, header.HandshakeXXPSK0,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// msg1: initiator -> responder (E only, no cert)
|
||||||
|
msg1, err := initM.Initiate(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, msg1)
|
||||||
|
|
||||||
|
// Responder processes msg1, should not complete yet, should produce msg2
|
||||||
|
msg2, result, err := respM.ProcessPacket(nil, msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, result, "XX should not complete on msg1")
|
||||||
|
assert.NotEmpty(t, msg2, "responder should produce msg2")
|
||||||
|
|
||||||
|
// Initiator processes msg2: gets responder's cert, produces msg3, and
|
||||||
|
// completes (WriteMessage for msg3 derives keys)
|
||||||
|
msg3, initResult, err := initM.ProcessPacket(nil, msg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3")
|
||||||
|
assert.NotEmpty(t, msg3, "initiator should produce msg3")
|
||||||
|
assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name())
|
||||||
|
|
||||||
|
// Responder processes msg3: gets initiator's cert and completes
|
||||||
|
_, respResult, err := respM.ProcessPacket(nil, msg3)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, respResult, "XX responder should complete on msg3")
|
||||||
|
assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name())
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages")
|
||||||
|
assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages")
|
||||||
|
|
||||||
|
// Verify keys work
|
||||||
|
ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
pt1, err := respResult.DKey.Decrypt(nil, nil, ct1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("three messages"), pt1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with
|
||||||
|
// IX since the cert is always in the payload. A 3-message pattern test (HybridIX)
|
||||||
|
// should exercise the case where cert arrives in msg3 and verify that completing
|
||||||
|
// without it fails.
|
||||||
|
|
||||||
|
func TestMachineExpiredCert(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519,
|
||||||
|
time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour),
|
||||||
|
nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
|
||||||
|
expCert, _, expKeyPEM, _ := ct.NewTestCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||||
|
"expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour),
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil,
|
||||||
|
)
|
||||||
|
expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM)
|
||||||
|
require.NoError(t, err)
|
||||||
|
expHsBytes, err := expCert.MarshalForHandshakes()
|
||||||
|
require.NoError(t, err)
|
||||||
|
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
|
|
||||||
|
expiredCS := &testCertState{
|
||||||
|
version: cert.Version2,
|
||||||
|
creds: map[cert.Version]*Credential{
|
||||||
|
cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
respCS := newTestCertState(
|
||||||
|
t, ca, caKey, "responder",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||||
|
)
|
||||||
|
|
||||||
|
_, respM, _, _, err := initiateHandshake(
|
||||||
|
t, expiredCS, testVerifier(caPool),
|
||||||
|
respCS, testVerifier(caPool),
|
||||||
|
)
|
||||||
|
require.ErrorContains(t, err, "verify cert")
|
||||||
|
assert.True(t, respM.Failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineNoCertNetworks(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca)
|
||||||
|
|
||||||
|
caHsBytes, err := ca.MarshalForHandshakes()
|
||||||
|
require.NoError(t, err)
|
||||||
|
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
|
|
||||||
|
noNetCS := &testCertState{
|
||||||
|
version: cert.Version2,
|
||||||
|
creds: map[cert.Version]*Credential{
|
||||||
|
cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
respCS := newTestCertState(
|
||||||
|
t, ca, caKey, "responder",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||||
|
)
|
||||||
|
|
||||||
|
_, respM, _, _, err := initiateHandshake(
|
||||||
|
t, noNetCS, testVerifier(caPool),
|
||||||
|
respCS, testVerifier(caPool),
|
||||||
|
)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, respM.Failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineDifferentCAs(t *testing.T) {
|
||||||
|
ca1, _, caKey1, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
ca2, _, caKey2, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
initCS := newTestCertState(
|
||||||
|
t, ca1, caKey1, "init",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||||
|
)
|
||||||
|
respCS := newTestCertState(
|
||||||
|
t, ca2, caKey2, "resp",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||||
|
)
|
||||||
|
|
||||||
|
_, respM, _, _, err := initiateHandshake(
|
||||||
|
t, initCS, testVerifier(ct.NewTestCAPool(ca1)),
|
||||||
|
respCS, testVerifier(ct.NewTestCAPool(ca2)),
|
||||||
|
)
|
||||||
|
require.ErrorContains(t, err, "verify cert")
|
||||||
|
assert.True(t, respM.Failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMachineVersionNegotiation(t *testing.T) {
|
||||||
|
ca1, _, caKey1, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
ca2, _, caKey2, _ := ct.NewTestCaCert(
|
||||||
|
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||||
|
)
|
||||||
|
caPool := ct.NewTestCAPool(ca1, ca2)
|
||||||
|
|
||||||
|
makeMultiVersionResp := func(t *testing.T) *testCertState {
|
||||||
|
t.Helper()
|
||||||
|
respCertV1, _, respKeyPEM, _ := ct.NewTestCert(
|
||||||
|
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
|
||||||
|
ca1.NotBefore(), ca1.NotAfter(),
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
|
||||||
|
)
|
||||||
|
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
|
||||||
|
respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2)
|
||||||
|
respHsV1, _ := respCertV1.MarshalForHandshakes()
|
||||||
|
respHsV2, _ := respCertV2.MarshalForHandshakes()
|
||||||
|
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
|
return &testCertState{
|
||||||
|
version: cert.Version1,
|
||||||
|
creds: map[cert.Version]*Credential{
|
||||||
|
cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs),
|
||||||
|
cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("responder matches initiator version", func(t *testing.T) {
|
||||||
|
initCS := newTestCertState(
|
||||||
|
t, ca2, caKey2, "init",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||||
|
)
|
||||||
|
respCS := makeMultiVersionResp(t)
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
|
||||||
|
initM, _, respResult, resp, err := initiateHandshake(
|
||||||
|
t, initCS, v,
|
||||||
|
respCS, v,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, respResult)
|
||||||
|
|
||||||
|
assert.Equal(t, cert.Version2, respResult.MyCert.Version(),
|
||||||
|
"responder should negotiate to initiator's version")
|
||||||
|
|
||||||
|
_, initResult, err := initM.ProcessPacket(nil, resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, initResult)
|
||||||
|
assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(),
|
||||||
|
"initiator should see V2 cert from responder")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("responder keeps version when no match available", func(t *testing.T) {
|
||||||
|
initCS := newTestCertState(
|
||||||
|
t, ca2, caKey2, "init",
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||||
|
)
|
||||||
|
|
||||||
|
respCert, _, respKeyPEM, _ := ct.NewTestCert(
|
||||||
|
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
|
||||||
|
ca1.NotBefore(), ca1.NotAfter(),
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
|
||||||
|
)
|
||||||
|
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
|
||||||
|
respHs, _ := respCert.MarshalForHandshakes()
|
||||||
|
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
|
respCS := &testCertState{
|
||||||
|
version: cert.Version1,
|
||||||
|
creds: map[cert.Version]*Credential{
|
||||||
|
cert.Version1: NewCredential(respCert, respHs, respKey, ncs),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
v := testVerifier(caPool)
|
||||||
|
_, _, respResult, _, err := initiateHandshake(
|
||||||
|
t, initCS, v,
|
||||||
|
respCS, v,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, respResult)
|
||||||
|
|
||||||
|
assert.Equal(t, cert.Version1, respResult.MyCert.Version(),
|
||||||
|
"responder should keep V1 when V2 not available")
|
||||||
|
})
|
||||||
|
}
|
||||||
54
handshake/patterns.go
Normal file
54
handshake/patterns.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/flynn/noise"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
)
|
||||||
|
|
||||||
|
// msgFlags tracks what application data a handshake message carries.
|
||||||
|
type msgFlags struct {
|
||||||
|
expectsPayload bool // message carries indexes and time
|
||||||
|
expectsCert bool // message carries the certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// subtypeInfo bundles the noise pattern with the per-message flags for a
|
||||||
|
// given handshake subtype.
|
||||||
|
type subtypeInfo struct {
|
||||||
|
pattern noise.HandshakePattern
|
||||||
|
msgs []msgFlags
|
||||||
|
}
|
||||||
|
|
||||||
|
// subtypeInfos defines the noise pattern and message content layout for each
|
||||||
|
// handshake subtype.
|
||||||
|
var subtypeInfos = map[header.MessageSubType]subtypeInfo{
|
||||||
|
// IX: 2 messages, both carry payload and cert
|
||||||
|
header.HandshakeIXPSK0: {
|
||||||
|
pattern: noise.HandshakeIX,
|
||||||
|
msgs: []msgFlags{
|
||||||
|
{expectsPayload: true, expectsCert: true},
|
||||||
|
{expectsPayload: true, expectsCert: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// XX: 3 messages
|
||||||
|
// msg1 (I->R): payload only
|
||||||
|
// msg2 (R->I): payload + cert
|
||||||
|
// msg3 (I->R): cert only
|
||||||
|
//header.HandshakeXXPSK0: {
|
||||||
|
// pattern: noise.HandshakeXX,
|
||||||
|
// msgs: []msgFlags{
|
||||||
|
// {expectsPayload: true, expectsCert: false},
|
||||||
|
// {expectsPayload: true, expectsCert: true},
|
||||||
|
// {expectsPayload: false, expectsCert: true},
|
||||||
|
// },
|
||||||
|
//},
|
||||||
|
}
|
||||||
|
|
||||||
|
func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) {
|
||||||
|
if info, ok := subtypeInfos[subtype]; ok {
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype)
|
||||||
|
}
|
||||||
63
handshake/patterns_test.go
Normal file
63
handshake/patterns_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/flynn/noise"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSubtypeInfo(t *testing.T) {
|
||||||
|
t.Run("IX", func(t *testing.T) {
|
||||||
|
info, err := subtypeInfoFor(header.HandshakeIXPSK0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name)
|
||||||
|
require.Len(t, info.msgs, 2)
|
||||||
|
// msg1: payload + cert
|
||||||
|
assert.True(t, info.msgs[0].expectsPayload)
|
||||||
|
assert.True(t, info.msgs[0].expectsCert)
|
||||||
|
// msg2: payload + cert
|
||||||
|
assert.True(t, info.msgs[1].expectsPayload)
|
||||||
|
assert.True(t, info.msgs[1].expectsCert)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("XX", func(t *testing.T) {
|
||||||
|
registerTestXXInfo(t)
|
||||||
|
info, err := subtypeInfoFor(header.HandshakeXXPSK0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name)
|
||||||
|
require.Len(t, info.msgs, 3)
|
||||||
|
// msg1: payload only
|
||||||
|
assert.True(t, info.msgs[0].expectsPayload)
|
||||||
|
assert.False(t, info.msgs[0].expectsCert)
|
||||||
|
// msg2: payload + cert
|
||||||
|
assert.True(t, info.msgs[1].expectsPayload)
|
||||||
|
assert.True(t, info.msgs[1].expectsCert)
|
||||||
|
// msg3: cert only
|
||||||
|
assert.False(t, info.msgs[2].expectsPayload)
|
||||||
|
assert.True(t, info.msgs[2].expectsCert)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown subtype returns error", func(t *testing.T) {
|
||||||
|
_, err := subtypeInfoFor(99)
|
||||||
|
require.ErrorIs(t, err, ErrUnknownSubtype)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerTestXXInfo temporarily registers XX subtype info for testing.
|
||||||
|
func registerTestXXInfo(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{
|
||||||
|
pattern: noise.HandshakeXX,
|
||||||
|
msgs: []msgFlags{
|
||||||
|
{expectsPayload: true, expectsCert: false},
|
||||||
|
{expectsPayload: true, expectsCert: true},
|
||||||
|
{expectsPayload: false, expectsCert: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
delete(subtypeInfos, header.HandshakeXXPSK0)
|
||||||
|
})
|
||||||
|
}
|
||||||
173
handshake/payload.go
Normal file
173
handshake/payload.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/encoding/protowire"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errInvalidHandshakeMessage = errors.New("invalid handshake message")
|
||||||
|
errInvalidHandshakeDetails = errors.New("invalid handshake details")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Payload represents the decoded fields of a handshake message.
|
||||||
|
// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
|
||||||
|
type Payload struct {
|
||||||
|
Cert []byte
|
||||||
|
InitiatorIndex uint32
|
||||||
|
ResponderIndex uint32
|
||||||
|
Time uint64
|
||||||
|
CertVersion uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proto field numbers for NebulaHandshakeDetails
|
||||||
|
const (
|
||||||
|
fieldCert = 1 // bytes
|
||||||
|
fieldInitiatorIndex = 2 // uint32
|
||||||
|
fieldResponderIndex = 3 // uint32
|
||||||
|
fieldTime = 5 // uint64
|
||||||
|
fieldCertVersion = 8 // uint32
|
||||||
|
)
|
||||||
|
|
||||||
|
// MarshalPayload encodes a handshake payload in protobuf wire format compatible
|
||||||
|
// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
|
||||||
|
// Returns out (which may be nil), with the marshalled Payload appended to it.
|
||||||
|
func MarshalPayload(out []byte, p Payload) []byte {
|
||||||
|
var details []byte
|
||||||
|
|
||||||
|
if len(p.Cert) > 0 {
|
||||||
|
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
||||||
|
details = protowire.AppendBytes(details, p.Cert)
|
||||||
|
}
|
||||||
|
if p.InitiatorIndex != 0 {
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, uint64(p.InitiatorIndex))
|
||||||
|
}
|
||||||
|
if p.ResponderIndex != 0 {
|
||||||
|
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, uint64(p.ResponderIndex))
|
||||||
|
}
|
||||||
|
if p.Time != 0 {
|
||||||
|
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, p.Time)
|
||||||
|
}
|
||||||
|
if p.CertVersion != 0 {
|
||||||
|
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, uint64(p.CertVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||||
|
out = protowire.AppendBytes(out, details)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
|
||||||
|
func UnmarshalPayload(b []byte) (Payload, error) {
|
||||||
|
var p Payload
|
||||||
|
|
||||||
|
for len(b) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(b)
|
||||||
|
if n < 0 {
|
||||||
|
return p, errInvalidHandshakeMessage
|
||||||
|
}
|
||||||
|
b = b[n:]
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case num == 1 && typ == protowire.BytesType:
|
||||||
|
details, n := protowire.ConsumeBytes(b)
|
||||||
|
if n < 0 {
|
||||||
|
return p, errInvalidHandshakeMessage
|
||||||
|
}
|
||||||
|
b = b[n:]
|
||||||
|
if err := unmarshalPayloadDetails(&p, details); err != nil {
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||||
|
if n < 0 {
|
||||||
|
return p, errInvalidHandshakeMessage
|
||||||
|
}
|
||||||
|
b = b[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalPayloadDetails(p *Payload, b []byte) error {
|
||||||
|
for len(b) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(b)
|
||||||
|
if n < 0 {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
b = b[n:]
|
||||||
|
|
||||||
|
// For known field numbers, reject any non-matching wire type as a
|
||||||
|
// hard error rather than silently skipping. The caller will catch
|
||||||
|
// missing-field cases downstream, but a wire-type mismatch on a tag
|
||||||
|
// we know is a peer protocol violation worth flagging here.
|
||||||
|
// Repeated occurrences of a singular field follow proto3 last-wins.
|
||||||
|
switch num {
|
||||||
|
case fieldCert:
|
||||||
|
if typ != protowire.BytesType {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
v, n := protowire.ConsumeBytes(b)
|
||||||
|
if n < 0 {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
p.Cert = append([]byte(nil), v...)
|
||||||
|
b = b[n:]
|
||||||
|
case fieldInitiatorIndex:
|
||||||
|
if typ != protowire.VarintType {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
v, n := protowire.ConsumeVarint(b)
|
||||||
|
if n < 0 || v > math.MaxUint32 {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
p.InitiatorIndex = uint32(v)
|
||||||
|
b = b[n:]
|
||||||
|
case fieldResponderIndex:
|
||||||
|
if typ != protowire.VarintType {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
v, n := protowire.ConsumeVarint(b)
|
||||||
|
if n < 0 || v > math.MaxUint32 {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
p.ResponderIndex = uint32(v)
|
||||||
|
b = b[n:]
|
||||||
|
case fieldTime:
|
||||||
|
if typ != protowire.VarintType {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
v, n := protowire.ConsumeVarint(b)
|
||||||
|
if n < 0 {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
p.Time = v
|
||||||
|
b = b[n:]
|
||||||
|
case fieldCertVersion:
|
||||||
|
if typ != protowire.VarintType {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
v, n := protowire.ConsumeVarint(b)
|
||||||
|
if n < 0 || v > math.MaxUint32 {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
p.CertVersion = uint32(v)
|
||||||
|
b = b[n:]
|
||||||
|
default:
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||||
|
if n < 0 {
|
||||||
|
return errInvalidHandshakeDetails
|
||||||
|
}
|
||||||
|
b = b[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
361
handshake/payload_test.go
Normal file
361
handshake/payload_test.go
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/encoding/protowire"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPayloadRoundTrip(t *testing.T) {
|
||||||
|
t.Run("all fields set", func(t *testing.T) {
|
||||||
|
data := MarshalPayload(nil, Payload{
|
||||||
|
Cert: []byte("test-cert-bytes"),
|
||||||
|
CertVersion: 2,
|
||||||
|
InitiatorIndex: 12345,
|
||||||
|
ResponderIndex: 67890,
|
||||||
|
Time: 1234567890,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
|
||||||
|
assert.Equal(t, uint32(12345), got.InitiatorIndex)
|
||||||
|
assert.Equal(t, uint32(67890), got.ResponderIndex)
|
||||||
|
assert.Equal(t, uint64(1234567890), got.Time)
|
||||||
|
assert.Equal(t, uint32(2), got.CertVersion)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("minimal fields", func(t *testing.T) {
|
||||||
|
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(1), got.InitiatorIndex)
|
||||||
|
assert.Equal(t, uint32(0), got.ResponderIndex)
|
||||||
|
assert.Equal(t, uint64(0), got.Time)
|
||||||
|
assert.Nil(t, got.Cert)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty payload", func(t *testing.T) {
|
||||||
|
data := MarshalPayload(nil, Payload{})
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("large cert bytes", func(t *testing.T) {
|
||||||
|
bigCert := make([]byte, 4096)
|
||||||
|
for i := range bigCert {
|
||||||
|
bigCert[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := MarshalPayload(nil, Payload{
|
||||||
|
Cert: bigCert,
|
||||||
|
CertVersion: 2,
|
||||||
|
InitiatorIndex: 999,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, bigCert, got.Cert)
|
||||||
|
assert.Equal(t, uint32(999), got.InitiatorIndex)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("append to existing buffer", func(t *testing.T) {
|
||||||
|
prefix := []byte("prefix")
|
||||||
|
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("prefix"), data[:6])
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data[6:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPayloadUnknownFields(t *testing.T) {
|
||||||
|
t.Run("unknown field in outer message is skipped", func(t *testing.T) {
|
||||||
|
// Marshal a normal payload then append an unknown field (field 99, varint)
|
||||||
|
data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
|
||||||
|
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||||
|
data = protowire.AppendVarint(data, 12345)
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown field in details is skipped", func(t *testing.T) {
|
||||||
|
// Build details with a known field + unknown field
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 77)
|
||||||
|
// Unknown field 50, varint
|
||||||
|
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 9999)
|
||||||
|
// Another known field after the unknown one
|
||||||
|
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 88)
|
||||||
|
|
||||||
|
// Wrap in outer message
|
||||||
|
var data []byte
|
||||||
|
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||||
|
data = protowire.AppendBytes(data, details)
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint32(77), got.InitiatorIndex)
|
||||||
|
assert.Equal(t, uint32(88), got.ResponderIndex)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
|
||||||
|
// Fields 6 and 7 are reserved in the proto definition
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 100)
|
||||||
|
details = protowire.AppendTag(details, 6, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 1)
|
||||||
|
details = protowire.AppendTag(details, 7, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 2)
|
||||||
|
|
||||||
|
var data []byte
|
||||||
|
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||||
|
data = protowire.AppendBytes(data, details)
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint32(100), got.InitiatorIndex)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPayloadBytesConsumed(t *testing.T) {
|
||||||
|
t.Run("all bytes consumed on valid input", func(t *testing.T) {
|
||||||
|
original := Payload{
|
||||||
|
Cert: []byte("cert"),
|
||||||
|
CertVersion: 2,
|
||||||
|
InitiatorIndex: 100,
|
||||||
|
ResponderIndex: 200,
|
||||||
|
Time: 999,
|
||||||
|
}
|
||||||
|
data := MarshalPayload(nil, original)
|
||||||
|
|
||||||
|
got, err := UnmarshalPayload(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Re-marshal and compare — proves we consumed and reproduced all fields
|
||||||
|
remarshaled := MarshalPayload(nil, got)
|
||||||
|
assert.Equal(t, data, remarshaled)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope
|
||||||
|
// so UnmarshalPayload can reach unmarshalPayloadDetails.
|
||||||
|
func wrapDetails(details []byte) []byte {
|
||||||
|
var out []byte
|
||||||
|
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||||
|
out = protowire.AppendBytes(out, details)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPayloadUnmarshalErrors(t *testing.T) {
|
||||||
|
t.Run("nil input", func(t *testing.T) {
|
||||||
|
got, err := UnmarshalPayload(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated outer tag", func(t *testing.T) {
|
||||||
|
_, err := UnmarshalPayload([]byte{0x80})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated outer details field", func(t *testing.T) {
|
||||||
|
_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated outer unknown field", func(t *testing.T) {
|
||||||
|
// Valid tag for unknown field 99 varint, but no value follows
|
||||||
|
var data []byte
|
||||||
|
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||||
|
_, err := UnmarshalPayload(data)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated details tag", func(t *testing.T) {
|
||||||
|
_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated cert bytes", func(t *testing.T) {
|
||||||
|
// Field 1 (cert), bytes type, length 10 but only 2 bytes
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
||||||
|
details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated initiator index varint", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||||
|
details = append(details, 0x80) // incomplete varint
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated responder index varint", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||||
|
details = append(details, 0x80)
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated time varint", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
||||||
|
details = append(details, 0x80)
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated cert version varint", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||||
|
details = append(details, 0x80)
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated unknown field in details", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||||
|
details = append(details, 0x80) // incomplete varint
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cert with wrong wire type rejected", func(t *testing.T) {
|
||||||
|
// fieldCert as Varint instead of Bytes.
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldCert, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 42)
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("initiator index with wrong wire type rejected", func(t *testing.T) {
|
||||||
|
// fieldInitiatorIndex as Bytes instead of Varint.
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType)
|
||||||
|
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("time with wrong wire type rejected", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldTime, protowire.BytesType)
|
||||||
|
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cert version with wrong wire type rejected", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType)
|
||||||
|
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
|
||||||
|
// Per proto3, multiple instances of a singular field are accepted and
|
||||||
|
// the last value wins. We keep this behavior so that peers using
|
||||||
|
// alternative encoders aren't rejected.
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 1)
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, 42)
|
||||||
|
got, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("initiator index varint overflow rejected", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cert version varint overflow rejected", func(t *testing.T) {
|
||||||
|
var details []byte
|
||||||
|
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||||
|
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||||
|
_, err := UnmarshalPayload(wrapDetails(details))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
|
||||||
|
// never panics, and for any input that parses cleanly, that re-marshal +
|
||||||
|
// re-parse is a fix-point. Inputs come from an authenticated peer (post-
|
||||||
|
// noise-decrypt), so the threat model is "valid peer behaving arbitrarily,"
|
||||||
|
// not "unauthenticated injection."
|
||||||
|
func FuzzPayload(f *testing.F) {
|
||||||
|
// Seed corpus with a handful of known-good shapes.
|
||||||
|
f.Add(MarshalPayload(nil, Payload{}))
|
||||||
|
f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
|
||||||
|
f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
|
||||||
|
f.Add(MarshalPayload(nil, Payload{
|
||||||
|
Cert: []byte("seed-cert"),
|
||||||
|
InitiatorIndex: 1,
|
||||||
|
ResponderIndex: 2,
|
||||||
|
Time: 3,
|
||||||
|
CertVersion: 2,
|
||||||
|
}))
|
||||||
|
f.Add([]byte{})
|
||||||
|
f.Add([]byte{0xff})
|
||||||
|
|
||||||
|
f.Fuzz(func(t *testing.T, data []byte) {
|
||||||
|
p1, err := UnmarshalPayload(data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For any input that parses, re-marshaling and re-parsing must
|
||||||
|
// yield an equivalent Payload. This catches dispatch bugs (e.g.
|
||||||
|
// emitting a field on marshal that we don't accept on parse) and
|
||||||
|
// any non-idempotent parsing behavior.
|
||||||
|
b2 := MarshalPayload(nil, p1)
|
||||||
|
p2, err := UnmarshalPayload(b2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
|
||||||
|
}
|
||||||
|
if !payloadsEqual(p1, p2) {
|
||||||
|
t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadsEqual(a, b Payload) bool {
|
||||||
|
return bytes.Equal(a.Cert, b.Cert) &&
|
||||||
|
a.InitiatorIndex == b.InitiatorIndex &&
|
||||||
|
a.ResponderIndex == b.ResponderIndex &&
|
||||||
|
a.Time == b.Time &&
|
||||||
|
a.CertVersion == b.CertVersion
|
||||||
|
}
|
||||||
746
handshake_ix.go
746
handshake_ix.go
@@ -1,746 +0,0 @@
|
|||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"github.com/slackhq/nebula/header"
|
|
||||||
"github.com/slackhq/nebula/udp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NOISE IX Handshakes
|
|
||||||
|
|
||||||
// This function constructs a handshake packet, but does not actually send it
|
|
||||||
// Sending is done by the handshake manager
|
|
||||||
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|
||||||
err := f.handshakeManager.allocateIndex(hh)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
cs := f.pki.getCertState()
|
|
||||||
v := cs.initiatingVersion
|
|
||||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
|
||||||
v = hh.initiatingVersionOverride
|
|
||||||
} else if v < cert.Version2 {
|
|
||||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
|
||||||
if a.Is6() {
|
|
||||||
v = cert.Version2
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
crt := cs.getCertificate(v)
|
|
||||||
if crt == nil {
|
|
||||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", v).
|
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
crtHs := cs.getHandshakeBytes(v)
|
|
||||||
if crtHs == nil {
|
|
||||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", v).
|
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", v).
|
|
||||||
Error("Failed to create connection state")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
hh.hostinfo.ConnectionState = ci
|
|
||||||
|
|
||||||
hs := &NebulaHandshake{
|
|
||||||
Details: &NebulaHandshakeDetails{
|
|
||||||
InitiatorIndex: hh.hostinfo.localIndexId,
|
|
||||||
Time: uint64(time.Now().UnixNano()),
|
|
||||||
Cert: crtHs,
|
|
||||||
CertVersion: uint32(v),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.multiPort.Tx || f.multiPort.Rx {
|
|
||||||
hs.Details.InitiatorMultiPort = &MultiPortDetails{
|
|
||||||
RxSupported: f.multiPort.Rx,
|
|
||||||
TxSupported: f.multiPort.Tx,
|
|
||||||
BasePort: uint32(f.multiPort.TxBasePort),
|
|
||||||
TotalPorts: uint32(f.multiPort.TxPorts),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
|
||||||
WithField("certVersion", v).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
|
|
||||||
|
|
||||||
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// We are sending handshake packet 1, so we don't expect to receive
|
|
||||||
// handshake packet 1 from the responder
|
|
||||||
ci.window.Update(f.l, 1)
|
|
||||||
|
|
||||||
hh.hostinfo.HandshakePacket[0] = msg
|
|
||||||
hh.ready = true
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
|
|
||||||
cs := f.pki.getCertState()
|
|
||||||
crt := cs.GetDefaultCertificate()
|
|
||||||
if crt == nil {
|
|
||||||
f.l.WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Error("Failed to create connection state")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark packet 1 as seen so it doesn't show up as missed
|
|
||||||
ci.window.Update(f.l, 1)
|
|
||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Error("Failed to call noise.ReadMessage")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hs := &NebulaHandshake{}
|
|
||||||
err = hs.Unmarshal(msg)
|
|
||||||
if err != nil || hs.Details == nil {
|
|
||||||
f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Error("Failed unmarshal handshake message")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Info("Handshake did not contain a certificate")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
|
||||||
if err != nil {
|
|
||||||
fp, fperr := rc.Fingerprint()
|
|
||||||
if fperr != nil {
|
|
||||||
fp = "<error generating certificate fingerprint>"
|
|
||||||
}
|
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithField("certVpnNetworks", rc.Networks()).
|
|
||||||
WithField("certFingerprint", fp)
|
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
e = e.WithField("cert", rc)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
|
||||||
f.l.WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
|
||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
|
||||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
|
||||||
if myCertOtherVersion == nil {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithError(err).WithFields(m{
|
|
||||||
"from": via,
|
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
|
||||||
"cert": remoteCert,
|
|
||||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Record the certificate we are actually using
|
|
||||||
ci.myCert = myCertOtherVersion
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
|
||||||
f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("cert", remoteCert).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Info("No networks in certificate")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
certName := remoteCert.Certificate.Name()
|
|
||||||
certVersion := remoteCert.Certificate.Version()
|
|
||||||
fingerprint := remoteCert.Fingerprint
|
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
|
||||||
vpnNetworks := remoteCert.Certificate.Networks()
|
|
||||||
|
|
||||||
anyVpnAddrsInCommon := false
|
|
||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
|
||||||
for i, network := range vpnNetworks {
|
|
||||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
|
||||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
vpnAddrs[i] = network.Addr()
|
|
||||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
|
||||||
anyVpnAddrsInCommon = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !via.IsRelayed {
|
|
||||||
// We only want to apply the remote allow list for direct tunnels here
|
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
Debug("lighthouse.remote_allow_list denied incoming handshake")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
myIndex, err := generateIndex(f.l)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var multiportTx, multiportRx bool
|
|
||||||
if f.multiPort.Rx || f.multiPort.Tx {
|
|
||||||
if hs.Details.InitiatorMultiPort != nil {
|
|
||||||
multiportTx = hs.Details.InitiatorMultiPort.RxSupported && f.multiPort.Tx
|
|
||||||
multiportRx = hs.Details.InitiatorMultiPort.TxSupported && f.multiPort.Rx
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.Details.ResponderMultiPort = &MultiPortDetails{
|
|
||||||
TxSupported: f.multiPort.Tx,
|
|
||||||
RxSupported: f.multiPort.Rx,
|
|
||||||
BasePort: uint32(f.multiPort.TxBasePort),
|
|
||||||
TotalPorts: uint32(f.multiPort.TxPorts),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(via.UdpAddr.Port()) {
|
|
||||||
// The other side sent us a handshake from a different port, make sure
|
|
||||||
// we send responses back to the BasePort
|
|
||||||
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort))
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo := &HostInfo{
|
|
||||||
ConnectionState: ci,
|
|
||||||
localIndexId: myIndex,
|
|
||||||
remoteIndexId: hs.Details.InitiatorIndex,
|
|
||||||
vpnAddrs: vpnAddrs,
|
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
|
||||||
lastHandshakeTime: hs.Details.Time,
|
|
||||||
multiportTx: multiportTx,
|
|
||||||
multiportRx: multiportRx,
|
|
||||||
relayState: RelayState{
|
|
||||||
relays: nil,
|
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
msgRxL := f.l.WithFields(m{
|
|
||||||
"vpnAddrs": vpnAddrs,
|
|
||||||
"from": via,
|
|
||||||
"certName": certName,
|
|
||||||
"certVersion": certVersion,
|
|
||||||
"fingerprint": fingerprint,
|
|
||||||
"issuer": issuer,
|
|
||||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
|
||||||
"responderIndex": hs.Details.ResponderIndex,
|
|
||||||
"remoteIndex": h.RemoteIndex,
|
|
||||||
"multiportTx": multiportTx,
|
|
||||||
"multiportRx": multiportRx,
|
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
|
||||||
})
|
|
||||||
|
|
||||||
if anyVpnAddrsInCommon {
|
|
||||||
msgRxL.Info("Handshake message received")
|
|
||||||
} else {
|
|
||||||
//todo warn if not lighthouse or relay?
|
|
||||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
|
||||||
if hs.Details.Cert == nil {
|
|
||||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.Details.CertVersion = uint32(ci.myCert.Version())
|
|
||||||
// Update the time in case their clock is way off from ours
|
|
||||||
hs.Details.Time = uint64(time.Now().UnixNano())
|
|
||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
|
||||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
|
||||||
return
|
|
||||||
} else if dKey == nil || eKey == nil {
|
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
|
|
||||||
copy(hostinfo.HandshakePacket[0], packet[header.Len:])
|
|
||||||
|
|
||||||
// Regardless of whether you are the sender or receiver, you should arrive here
|
|
||||||
// and complete standing up the connection.
|
|
||||||
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
|
|
||||||
copy(hostinfo.HandshakePacket[2], msg)
|
|
||||||
|
|
||||||
// We are sending handshake packet 2, so we don't expect to receive
|
|
||||||
// handshake packet 2 from the initiator.
|
|
||||||
ci.window.Update(f.l, 2)
|
|
||||||
|
|
||||||
ci.peerCert = remoteCert
|
|
||||||
ci.dKey = NewNebulaCipherState(dKey)
|
|
||||||
ci.eKey = NewNebulaCipherState(eKey)
|
|
||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
|
||||||
if !via.IsRelayed {
|
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
|
||||||
}
|
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
|
||||||
if err != nil {
|
|
||||||
switch err {
|
|
||||||
case ErrAlreadySeen:
|
|
||||||
if hostinfo.multiportRx {
|
|
||||||
// The other host is sending to us with multiport, so only grab the IP
|
|
||||||
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
|
|
||||||
}
|
|
||||||
// Update remote if preferred
|
|
||||||
if existing.SetRemoteIfPreferred(f.hostMap, via) {
|
|
||||||
// Send a test packet to ensure the other side has also switched to
|
|
||||||
// the preferred remote
|
|
||||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
|
||||||
}
|
|
||||||
|
|
||||||
msg = existing.HandshakePacket[2]
|
|
||||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
|
||||||
if !via.IsRelayed {
|
|
||||||
err := f.outside.WriteTo(msg, via.UdpAddr)
|
|
||||||
if multiportTx {
|
|
||||||
// TODO remove alloc here
|
|
||||||
raw := make([]byte, len(msg)+udp.RawOverhead)
|
|
||||||
copy(raw[udp.RawOverhead:], msg)
|
|
||||||
err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr)
|
|
||||||
} else {
|
|
||||||
err = f.outside.WriteTo(msg, via.UdpAddr)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
|
||||||
WithError(err).Error("Failed to send handshake message")
|
|
||||||
} else {
|
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
|
||||||
Info("Handshake message sent")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
if via.relay == nil {
|
|
||||||
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
|
||||||
Info("Handshake message sent")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case ErrExistingHostInfo:
|
|
||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
|
||||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Info("Handshake too old")
|
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
|
||||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
|
||||||
return
|
|
||||||
case ErrLocalIndexCollision:
|
|
||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
|
|
||||||
Error("Failed to add HostInfo due to localIndex collision")
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
|
||||||
// And we forget to update it here
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Error("Failed to add HostInfo to HostMap")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do the send
|
|
||||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
|
||||||
if !via.IsRelayed {
|
|
||||||
if multiportTx {
|
|
||||||
// TODO remove alloc here
|
|
||||||
raw := make([]byte, len(msg)+udp.RawOverhead)
|
|
||||||
copy(raw[udp.RawOverhead:], msg)
|
|
||||||
err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr)
|
|
||||||
} else {
|
|
||||||
err = f.outside.WriteTo(msg, via.UdpAddr)
|
|
||||||
}
|
|
||||||
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Error("Failed to send handshake")
|
|
||||||
} else {
|
|
||||||
log.Info("Handshake message sent")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if via.relay == nil {
|
|
||||||
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
|
||||||
// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
|
|
||||||
// it's correctly marked as working.
|
|
||||||
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
Info("Handshake message sent")
|
|
||||||
}
|
|
||||||
|
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
|
||||||
|
|
||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
|
|
||||||
if hh == nil {
|
|
||||||
// Nothing here to tear down, got a bogus stage 2 packet
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
hh.Lock()
|
|
||||||
defer hh.Unlock()
|
|
||||||
|
|
||||||
hostinfo := hh.hostinfo
|
|
||||||
if !via.IsRelayed {
|
|
||||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ci := hostinfo.ConnectionState
|
|
||||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
|
||||||
Error("Failed to call noise.ReadMessage")
|
|
||||||
|
|
||||||
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
|
|
||||||
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
|
||||||
// near future
|
|
||||||
return false
|
|
||||||
} else if dKey == nil || eKey == nil {
|
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
Error("Noise did not arrive at a key")
|
|
||||||
|
|
||||||
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
|
|
||||||
// the handshake state machine. Tear it down
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
hs := &NebulaHandshake{}
|
|
||||||
err = hs.Unmarshal(msg)
|
|
||||||
if err != nil || hs.Details == nil {
|
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
|
||||||
|
|
||||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if (f.multiPort.Tx || f.multiPort.Rx) && hs.Details.ResponderMultiPort != nil {
|
|
||||||
hostinfo.multiportTx = hs.Details.ResponderMultiPort.RxSupported && f.multiPort.Tx
|
|
||||||
hostinfo.multiportRx = hs.Details.ResponderMultiPort.TxSupported && f.multiPort.Rx
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(via.UdpAddr.Port()) {
|
|
||||||
// The other side sent us a handshake from a different port, make sure
|
|
||||||
// we send responses back to the BasePort
|
|
||||||
via.UdpAddr = netip.AddrPortFrom(
|
|
||||||
via.UdpAddr.Addr(),
|
|
||||||
uint16(hs.Details.ResponderMultiPort.BasePort),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
Info("Handshake did not contain a certificate")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
|
||||||
if err != nil {
|
|
||||||
fp, err := rc.Fingerprint()
|
|
||||||
if err != nil {
|
|
||||||
fp = "<error generating certificate fingerprint>"
|
|
||||||
}
|
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
WithField("certFingerprint", fp).
|
|
||||||
WithField("certVpnNetworks", rc.Networks())
|
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
e = e.WithField("cert", rc)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
|
||||||
f.l.WithField("from", via).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
|
||||||
f.l.WithError(err).WithField("from", via).
|
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("cert", remoteCert).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
Info("No networks in certificate")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
vpnNetworks := remoteCert.Certificate.Networks()
|
|
||||||
certName := remoteCert.Certificate.Name()
|
|
||||||
certVersion := remoteCert.Certificate.Version()
|
|
||||||
fingerprint := remoteCert.Fingerprint
|
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
|
||||||
|
|
||||||
hostinfo.remoteIndexId = hs.Details.ResponderIndex
|
|
||||||
hostinfo.lastHandshakeTime = hs.Details.Time
|
|
||||||
|
|
||||||
// Store their cert and our symmetric keys
|
|
||||||
ci.peerCert = remoteCert
|
|
||||||
ci.dKey = NewNebulaCipherState(dKey)
|
|
||||||
ci.eKey = NewNebulaCipherState(eKey)
|
|
||||||
|
|
||||||
// Make sure the current udpAddr being used is set for responding
|
|
||||||
if !via.IsRelayed {
|
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
|
||||||
} else {
|
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
correctHostResponded := false
|
|
||||||
anyVpnAddrsInCommon := false
|
|
||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
|
||||||
for i, network := range vpnNetworks {
|
|
||||||
vpnAddrs[i] = network.Addr()
|
|
||||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
|
||||||
anyVpnAddrsInCommon = true
|
|
||||||
}
|
|
||||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
|
||||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
|
||||||
correctHostResponded = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the right host responded
|
|
||||||
if !correctHostResponded {
|
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
|
||||||
WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
Info("Incorrect host responded to handshake")
|
|
||||||
|
|
||||||
// Release our old handshake from pending, it should not continue
|
|
||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
|
||||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
|
||||||
// Block the current used address
|
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
|
||||||
newHH.hostinfo.remotes.BlockRemote(via)
|
|
||||||
|
|
||||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
|
||||||
WithField("vpnNetworks", vpnNetworks).
|
|
||||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
|
||||||
Info("Blocked addresses for handshakes")
|
|
||||||
|
|
||||||
// Swap the packet store to benefit the original intended recipient
|
|
||||||
newHH.packetStore = hh.packetStore
|
|
||||||
hh.packetStore = []*cachedPacket{}
|
|
||||||
|
|
||||||
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
|
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
|
||||||
f.sendCloseTunnel(hostinfo)
|
|
||||||
})
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark packet 2 as seen so it doesn't show up as missed
|
|
||||||
ci.window.Update(f.l, 2)
|
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
|
||||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
WithField("durationNs", duration).
|
|
||||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
|
||||||
WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx)
|
|
||||||
if anyVpnAddrsInCommon {
|
|
||||||
msgRxL.Info("Handshake message received")
|
|
||||||
} else {
|
|
||||||
//todo warn if not lighthouse or relay?
|
|
||||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hh.packetStore) > 0 {
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
out := make([]byte, mtu)
|
|
||||||
for _, cp := range hh.packetStore {
|
|
||||||
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
|
|
||||||
}
|
|
||||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
|
||||||
f.metricHandshakes.Update(duration)
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
@@ -27,7 +28,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
initiatingVersion: cert.Version1,
|
initiatingVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1Credential: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||||
@@ -100,3 +101,137 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
|
|||||||
func (mw *mockEncWriter) GetCertState() *CertState {
|
func (mw *mockEncWriter) GetCertState() *CertState {
|
||||||
return &CertState{initiatingVersion: cert.Version2}
|
return &CertState{initiatingVersion: cert.Version2}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidatePeerCert(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
|
||||||
|
myNetwork := netip.MustParsePrefix("10.0.0.1/24")
|
||||||
|
myAddrTable := new(bart.Lite)
|
||||||
|
myAddrTable.Insert(netip.PrefixFrom(myNetwork.Addr(), myNetwork.Addr().BitLen()))
|
||||||
|
myNetTable := new(bart.Lite)
|
||||||
|
myNetTable.Insert(myNetwork.Masked())
|
||||||
|
|
||||||
|
newHM := func() *HandshakeManager {
|
||||||
|
hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
|
||||||
|
hm.f = &Interface{
|
||||||
|
handshakeManager: hm,
|
||||||
|
pki: &PKI{},
|
||||||
|
l: l,
|
||||||
|
myVpnAddrsTable: myAddrTable,
|
||||||
|
myVpnNetworksTable: myNetTable,
|
||||||
|
lightHouse: hm.lightHouse,
|
||||||
|
}
|
||||||
|
return hm
|
||||||
|
}
|
||||||
|
|
||||||
|
cached := func(networks ...netip.Prefix) *cert.CachedCertificate {
|
||||||
|
return &cert.CachedCertificate{
|
||||||
|
Certificate: &dummyCert{name: "peer", networks: networks},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
via := ViaSender{
|
||||||
|
UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"),
|
||||||
|
IsRelayed: true, // skip the remote allow list (covered separately)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("addr inside our networks sets anyVpnAddrsInCommon", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
// 10.0.0.2 falls inside our 10.0.0.0/24
|
||||||
|
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.2/24")))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.True(t, common)
|
||||||
|
assert.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.2")}, addrs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("addr outside our networks leaves anyVpnAddrsInCommon false", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("192.168.1.5/24")))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.False(t, common)
|
||||||
|
assert.Equal(t, []netip.Addr{netip.MustParseAddr("192.168.1.5")}, addrs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("any matching network is enough", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
addrs, common, ok := hm.validatePeerCert(via, cached(
|
||||||
|
netip.MustParsePrefix("192.168.1.5/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.42/24"),
|
||||||
|
))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.True(t, common)
|
||||||
|
assert.Len(t, addrs, 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("self-handshake is rejected", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
// 10.0.0.1 is in myVpnAddrsTable
|
||||||
|
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.1/24")))
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.False(t, common)
|
||||||
|
assert.Nil(t, addrs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cert with no networks is rejected", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
addrs, common, ok := hm.validatePeerCert(via, cached())
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.False(t, common)
|
||||||
|
assert.Nil(t, addrs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleIncomingDispatch(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
|
||||||
|
newHM := func() *HandshakeManager {
|
||||||
|
hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
|
||||||
|
hm.f = &Interface{
|
||||||
|
handshakeManager: hm,
|
||||||
|
pki: &PKI{},
|
||||||
|
l: l,
|
||||||
|
}
|
||||||
|
return hm
|
||||||
|
}
|
||||||
|
|
||||||
|
via := ViaSender{
|
||||||
|
UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"),
|
||||||
|
IsRelayed: true, // bypass remote allow list
|
||||||
|
}
|
||||||
|
|
||||||
|
// A packet body of zero length is fine for these tests: dispatch is
|
||||||
|
// gated on header fields, and we assert that we never reach noise/cert
|
||||||
|
// processing for any of the malformed shapes here.
|
||||||
|
pkt := make([]byte, header.Len)
|
||||||
|
|
||||||
|
t.Run("unsupported subtype dropped", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
h := &header.H{Type: header.Handshake, Subtype: header.MessageSubType(99), MessageCounter: 1}
|
||||||
|
hm.HandleIncoming(via, pkt, h)
|
||||||
|
assert.Empty(t, hm.indexes, "no pending handshake should be created")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stage-1 with non-zero RemoteIndex dropped", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
h := &header.H{
|
||||||
|
Type: header.Handshake,
|
||||||
|
Subtype: header.HandshakeIXPSK0,
|
||||||
|
RemoteIndex: 0xdeadbeef,
|
||||||
|
MessageCounter: 1,
|
||||||
|
}
|
||||||
|
hm.HandleIncoming(via, pkt, h)
|
||||||
|
assert.Empty(t, hm.indexes, "spoofed stage-1 must not create a pending machine")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("continuation with no matching pending index dropped", func(t *testing.T) {
|
||||||
|
hm := newHM()
|
||||||
|
h := &header.H{
|
||||||
|
Type: header.Handshake,
|
||||||
|
Subtype: header.HandshakeIXPSK0,
|
||||||
|
RemoteIndex: 0xcafef00d,
|
||||||
|
MessageCounter: 2,
|
||||||
|
}
|
||||||
|
hm.HandleIncoming(via, pkt, h)
|
||||||
|
assert.Empty(t, hm.indexes, "orphan stage-2 must not create state")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -174,6 +174,10 @@ func (h *H) SubTypeName() string {
|
|||||||
return SubTypeName(h.Type, h.Subtype)
|
return SubTypeName(h.Type, h.Subtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *H) IsValidSubType() bool {
|
||||||
|
return IsValidSubType(h.Type, h.Subtype)
|
||||||
|
}
|
||||||
|
|
||||||
// SubTypeName will transform a nebula message sub type into a human string
|
// SubTypeName will transform a nebula message sub type into a human string
|
||||||
func SubTypeName(t MessageType, s MessageSubType) string {
|
func SubTypeName(t MessageType, s MessageSubType) string {
|
||||||
if n, ok := subTypeMap[t]; ok {
|
if n, ok := subTypeMap[t]; ok {
|
||||||
@@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string {
|
|||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsValidSubType(t MessageType, s MessageSubType) bool {
|
||||||
|
if n, ok := subTypeMap[t]; ok {
|
||||||
|
if _, ok := (*n)[s]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// NewHeader turns bytes into a header
|
// NewHeader turns bytes into a header
|
||||||
func NewHeader(b []byte) (*H, error) {
|
func NewHeader(b []byte) (*H, error) {
|
||||||
h := new(H)
|
h := new(H)
|
||||||
|
|||||||
80
hostmap.go
80
hostmap.go
@@ -1,9 +1,11 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -13,10 +15,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||||
@@ -60,7 +62,7 @@ type HostMap struct {
|
|||||||
RemoteIndexes map[uint32]*HostInfo
|
RemoteIndexes map[uint32]*HostInfo
|
||||||
Hosts map[netip.Addr]*HostInfo
|
Hosts map[netip.Addr]*HostInfo
|
||||||
preferredRanges atomic.Pointer[[]netip.Prefix]
|
preferredRanges atomic.Pointer[[]netip.Prefix]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
|
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
|
||||||
@@ -319,7 +321,7 @@ type cachedPacketMetrics struct {
|
|||||||
dropped metrics.Counter
|
dropped metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
|
||||||
hm := newHostMap(l)
|
hm := newHostMap(l)
|
||||||
|
|
||||||
hm.reload(c, true)
|
hm.reload(c, true)
|
||||||
@@ -327,13 +329,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
|||||||
hm.reload(c, false)
|
hm.reload(c, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("preferredRanges", hm.GetPreferredRanges()).
|
l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
|
||||||
Info("Main HostMap created")
|
|
||||||
|
|
||||||
return hm
|
return hm
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostMap(l *logrus.Logger) *HostMap {
|
func newHostMap(l *slog.Logger) *HostMap {
|
||||||
return &HostMap{
|
return &HostMap{
|
||||||
Indexes: map[uint32]*HostInfo{},
|
Indexes: map[uint32]*HostInfo{},
|
||||||
Relays: map[uint32]*HostInfo{},
|
Relays: map[uint32]*HostInfo{},
|
||||||
@@ -352,7 +353,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
|||||||
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
|
hm.l.Warn("Failed to parse preferred ranges, ignoring",
|
||||||
|
"error", err,
|
||||||
|
"range", rawPreferredRanges,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -361,7 +365,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
|||||||
|
|
||||||
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
||||||
if !initial {
|
if !initial {
|
||||||
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
|
hm.l.Info("preferred_ranges changed",
|
||||||
|
"oldPreferredRanges", *oldRanges,
|
||||||
|
"newPreferredRanges", preferredRanges,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -494,10 +501,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
|
|||||||
hm.Indexes = map[uint32]*HostInfo{}
|
hm.Indexes = map[uint32]*HostInfo{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
|
hm.l.Debug("Hostmap hostInfo deleted",
|
||||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"hostMap", m{"mapTotalSize": len(hm.Hosts),
|
||||||
Debug("Hostmap hostInfo deleted")
|
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLastHostinfo {
|
if isLastHostinfo {
|
||||||
@@ -610,9 +618,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
|
|||||||
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
|
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
|
||||||
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
|
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
|
||||||
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||||
if f.serveDns {
|
if f.dnsServer != nil {
|
||||||
remoteCert := hostinfo.ConnectionState.peerCert
|
remoteCert := hostinfo.ConnectionState.peerCert
|
||||||
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
for _, addr := range hostinfo.vpnAddrs {
|
for _, addr := range hostinfo.vpnAddrs {
|
||||||
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
||||||
@@ -621,10 +629,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
|||||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
hm.l.Debug("Hostmap vpnIp added",
|
||||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
|
"hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
||||||
Debug("Hostmap vpnIp added")
|
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -790,18 +799,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
// logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
|
||||||
|
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
|
||||||
if i == nil {
|
if i == nil {
|
||||||
return logrus.NewEntry(l)
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
li := l.WithField("vpnAddrs", i.vpnAddrs).
|
li := l.With(
|
||||||
WithField("localIndex", i.localIndexId).
|
"vpnAddrs", i.vpnAddrs,
|
||||||
WithField("remoteIndex", i.remoteIndexId)
|
"localIndex", i.localIndexId,
|
||||||
|
"remoteIndex", i.remoteIndexId,
|
||||||
|
)
|
||||||
|
|
||||||
if connState := i.ConnectionState; connState != nil {
|
if connState := i.ConnectionState; connState != nil {
|
||||||
if peerCert := connState.peerCert; peerCert != nil {
|
if peerCert := connState.peerCert; peerCert != nil {
|
||||||
li = li.WithField("certName", peerCert.Certificate.Name())
|
li = li.With("certName", peerCert.Certificate.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -810,14 +822,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
|||||||
|
|
||||||
// Utility functions
|
// Utility functions
|
||||||
|
|
||||||
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||||
//FIXME: This function is pretty garbage
|
//FIXME: This function is pretty garbage
|
||||||
var finalAddrs []netip.Addr
|
var finalAddrs []netip.Addr
|
||||||
ifaces, _ := net.Interfaces()
|
ifaces, _ := net.Interfaces()
|
||||||
for _, i := range ifaces {
|
for _, i := range ifaces {
|
||||||
allow := allowList.AllowName(i.Name)
|
allow := allowList.AllowName(i.Name)
|
||||||
if l.Level >= logrus.TraceLevel {
|
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
|
l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
|
||||||
|
"interfaceName", i.Name,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -835,8 +850,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !addr.IsValid() {
|
if !addr.IsValid() {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
|
l.Debug("addr was invalid", "localAddr", rawAddr)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -844,8 +859,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
|||||||
|
|
||||||
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
||||||
isAllowed := allowList.Allow(addr)
|
isAllowed := allowList.Allow(addr)
|
||||||
if l.Level >= logrus.TraceLevel {
|
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
|
l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
|
||||||
|
"localAddr", addr,
|
||||||
|
"allowed", isAllowed,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !isAllowed {
|
if !isAllowed {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
|
|
||||||
func TestHostMap_reload(t *testing.T) {
|
func TestHostMap_reload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(test.NewLogger())
|
||||||
|
|
||||||
hm := NewHostMapFromConfig(l, c)
|
hm := NewHostMapFromConfig(l, c)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build e2e_testing
|
//go:build e2e_testing
|
||||||
// +build e2e_testing
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
119
inside.go
119
inside.go
@@ -1,9 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
@@ -15,8 +16,11 @@ import (
|
|||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
f.l.Debug("Error while validating outbound packet",
|
||||||
|
"packet", packet,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -36,7 +40,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
if immediatelyForwardToSelf {
|
if immediatelyForwardToSelf {
|
||||||
_, err := f.readers[q].Write(packet)
|
_, err := f.readers[q].Write(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to forward to tun")
|
f.l.Error("Failed to forward to tun", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||||
@@ -55,10 +59,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
|
||||||
WithField("fwPacket", fwPacket).
|
"vpnAddr", fwPacket.RemoteAddr,
|
||||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
"fwPacket", fwPacket,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -73,11 +78,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).
|
hostinfo.logger(f.l).Debug("dropping outbound packet",
|
||||||
WithField("fwPacket", fwPacket).
|
"fwPacket", fwPacket,
|
||||||
WithField("reason", dropReason).
|
"reason", dropReason,
|
||||||
Debugln("dropping outbound packet")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -94,7 +99,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
|||||||
|
|
||||||
_, err := f.readers[q].Write(out)
|
_, err := f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.Error("Failed to write to tun", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,11 +114,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(out) > iputil.MaxRejectPacketSize {
|
if len(out) > iputil.MaxRejectPacketSize {
|
||||||
if f.l.GetLevel() >= logrus.InfoLevel {
|
if f.l.Enabled(context.Background(), slog.LevelInfo) {
|
||||||
f.l.
|
f.l.Info("rejectOutside: packet too big, not sending",
|
||||||
WithField("packet", packet).
|
"packet", packet,
|
||||||
WithField("outPacket", out).
|
"outPacket", out,
|
||||||
Info("rejectOutside: packet too big, not sending")
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -185,10 +190,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
|
|||||||
// This would also need to interact with unsafe_route updates through reloading the config or
|
// This would also need to interact with unsafe_route updates through reloading the config or
|
||||||
// use of the use_system_route_table option
|
// use of the use_system_route_table option
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("destination", destinationAddr).
|
f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
|
||||||
WithField("originalGateway", gatewayAddr).
|
"destination", destinationAddr,
|
||||||
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
|
"originalGateway", gatewayAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range gateways {
|
for i := range gateways {
|
||||||
@@ -214,17 +220,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
fp := &firewall.Packet{}
|
fp := &firewall.Packet{}
|
||||||
err := newPacket(p, false, fp)
|
err := newPacket(p, false, fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("fwPacket", fp).
|
f.l.Debug("dropping cached packet",
|
||||||
WithField("reason", dropReason).
|
"fwPacket", fp,
|
||||||
Debugln("dropping cached packet")
|
"reason", dropReason,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -240,9 +247,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
|
|||||||
})
|
})
|
||||||
|
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).
|
f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
|
||||||
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
|
"vpnAddr", vpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -298,12 +306,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
if noiseutil.EncryptLockNeeded {
|
if noiseutil.EncryptLockNeeded {
|
||||||
via.ConnectionState.writeLock.Unlock()
|
via.ConnectionState.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
via.logger(f.l).
|
via.logger(f.l).Error("SendVia out buffer not large enough for relay",
|
||||||
WithField("outCap", cap(out)).
|
"outCap", cap(out),
|
||||||
WithField("payloadLen", len(ad)).
|
"payloadLen", len(ad),
|
||||||
WithField("headerLen", len(out)).
|
"headerLen", len(out),
|
||||||
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()).
|
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
|
||||||
Error("SendVia out buffer not large enough for relay")
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,12 +331,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
via.ConnectionState.writeLock.Unlock()
|
via.ConnectionState.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
|
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = f.writers[0].WriteTo(out, via.remote)
|
err = f.writers[0].WriteTo(out, via.remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
|
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
|
||||||
}
|
}
|
||||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||||
}
|
}
|
||||||
@@ -384,8 +392,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -395,10 +405,12 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
ci.writeLock.Unlock()
|
ci.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
|
||||||
WithField("udpAddr", remote).WithField("counter", c).
|
"error", err,
|
||||||
WithField("attemptedCounter", c).
|
"udpAddr", remote,
|
||||||
Error("Failed to encrypt outgoing packet")
|
"counter", c,
|
||||||
|
"attemptedCounter", c,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -411,8 +423,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
err = f.writers[q].WriteTo(out, remote)
|
err = f.writers[q].WriteTo(out, remote)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
"error", err,
|
||||||
|
"udpAddr", remote,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else if hostinfo.remote.IsValid() {
|
} else if hostinfo.remote.IsValid() {
|
||||||
if multiport {
|
if multiport {
|
||||||
@@ -423,8 +437,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
"error", err,
|
||||||
|
"udpAddr", remote,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Try to send via a relay
|
// Try to send via a relay
|
||||||
@@ -432,7 +448,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.relayState.DeleteRelay(relayIP)
|
hostinfo.relayState.DeleteRelay(relayIP)
|
||||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
|
||||||
|
"relay", relayIP,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||||
// +build darwin dragonfly freebsd netbsd openbsd
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd
|
//go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd
|
||||||
// +build !darwin,!dragonfly,!freebsd,!netbsd,!openbsd
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
158
interface.go
158
interface.go
@@ -5,15 +5,15 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"sync"
|
||||||
"runtime"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -30,7 +30,7 @@ type InterfaceConfig struct {
|
|||||||
pki *PKI
|
pki *PKI
|
||||||
Cipher string
|
Cipher string
|
||||||
Firewall *Firewall
|
Firewall *Firewall
|
||||||
ServeDns bool
|
DnsServer *dnsServer
|
||||||
HandshakeManager *HandshakeManager
|
HandshakeManager *HandshakeManager
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
@@ -47,7 +47,7 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
@@ -58,7 +58,7 @@ type Interface struct {
|
|||||||
firewall *Firewall
|
firewall *Firewall
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
handshakeManager *HandshakeManager
|
handshakeManager *HandshakeManager
|
||||||
serveDns bool
|
dnsServer *dnsServer
|
||||||
createTime time.Time
|
createTime time.Time
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
myBroadcastAddrsTable *bart.Lite
|
myBroadcastAddrsTable *bart.Lite
|
||||||
@@ -86,17 +86,25 @@ type Interface struct {
|
|||||||
|
|
||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
udpRaw *udp.RawConn
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// fatalErr holds the first unexpected reader error that caused shutdown.
|
||||||
|
// nil means "no fatal error" (yet)
|
||||||
|
fatalErr atomic.Pointer[error]
|
||||||
|
// triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr
|
||||||
|
triggerShutdown func()
|
||||||
|
|
||||||
|
udpRaw *udp.RawConn
|
||||||
multiPort MultiPortConfig
|
multiPort MultiPortConfig
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type MultiPortConfig struct {
|
type MultiPortConfig struct {
|
||||||
@@ -176,12 +184,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
|
ctx: ctx,
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
outside: c.Outside,
|
outside: c.Outside,
|
||||||
inside: c.Inside,
|
inside: c.Inside,
|
||||||
firewall: c.Firewall,
|
firewall: c.Firewall,
|
||||||
serveDns: c.ServeDns,
|
dnsServer: c.DnsServer,
|
||||||
handshakeManager: c.HandshakeManager,
|
handshakeManager: c.HandshakeManager,
|
||||||
createTime: time.Now(),
|
createTime: time.Now(),
|
||||||
lightHouse: c.lightHouse,
|
lightHouse: c.lightHouse,
|
||||||
@@ -222,18 +231,21 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() {
|
func (f *Interface) activate() error {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to get udp listen address")
|
f.l.Error("Failed to get udp listen address", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
|
f.l.Info("Nebula interface is active",
|
||||||
WithField("build", f.version).WithField("udpAddr", addr).
|
"interface", f.inside.Name(),
|
||||||
WithField("boringcrypto", boringEnabled()).
|
"networks", f.myVpnNetworks,
|
||||||
Info("Nebula interface is active")
|
"build", f.version,
|
||||||
|
"udpAddr", addr,
|
||||||
|
"boringcrypto", boringEnabled(),
|
||||||
|
)
|
||||||
|
|
||||||
if f.routines > 1 {
|
if f.routines > 1 {
|
||||||
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
||||||
@@ -252,33 +264,58 @@ func (f *Interface) activate() {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.inside.Activate(); err != nil {
|
f.wg.Add(1) // for us to wait on Close() to return
|
||||||
|
if err = f.inside.Activate(); err != nil {
|
||||||
|
f.wg.Done()
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) run() {
|
func (f *Interface) run() (func() error, error) {
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenOut(i)
|
f.wg.Go(func() {
|
||||||
|
f.listenOut(i)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenIn(f.readers[i], i)
|
f.wg.Go(func() {
|
||||||
|
f.listenIn(f.readers[i], i)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() error {
|
||||||
|
f.wg.Wait()
|
||||||
|
if e := f.fatalErr.Load(); e != nil {
|
||||||
|
return *e
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one
|
||||||
|
func (f *Interface) onFatal(err error) {
|
||||||
|
swapped := f.fatalErr.CompareAndSwap(nil, &err)
|
||||||
|
if !swapped {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.triggerShutdown != nil {
|
||||||
|
f.triggerShutdown()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
@@ -286,42 +323,47 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
plaintext := make([]byte, udp.MTU)
|
plaintext := make([]byte, udp.MTU)
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil && !f.closed.Load() {
|
||||||
|
f.l.Error("Error while reading inbound packet, closing", "error", err)
|
||||||
|
f.onFatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.l.Debug("underlay reader is done", "reader", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
|
||||||
|
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
if !f.closed.Load() {
|
||||||
return
|
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
|
||||||
|
f.onFatal(err)
|
||||||
}
|
}
|
||||||
|
break
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.l.Debug("overlay reader is done", "reader", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
@@ -341,7 +383,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
|
|||||||
if initial || c.HasChanged("pki.disconnect_invalid") {
|
if initial || c.HasChanged("pki.disconnect_invalid") {
|
||||||
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
|
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
|
||||||
if !initial {
|
if !initial {
|
||||||
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
|
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -355,7 +397,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
|
|
||||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
f.l.Error("Error while creating firewall during reload", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -368,10 +410,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||||
// safe and just reset conntrack in this case.
|
// safe and just reset conntrack in this case.
|
||||||
if fw.rulesVersion == 0 {
|
if fw.rulesVersion == 0 {
|
||||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
|
||||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
"firewallHashes", fw.GetRuleHashes(),
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
"rulesVersion", fw.rulesVersion,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
fw.Conntrack = conntrack
|
fw.Conntrack = conntrack
|
||||||
}
|
}
|
||||||
@@ -379,10 +422,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
f.firewall = fw
|
f.firewall = fw
|
||||||
|
|
||||||
oldFw.Destroy()
|
oldFw.Destroy()
|
||||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
f.l.Info("New firewall has been installed",
|
||||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
"firewallHashes", fw.GetRuleHashes(),
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||||
Info("New firewall has been installed")
|
"rulesVersion", fw.rulesVersion,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadSendRecvError(c *config.C) {
|
func (f *Interface) reloadSendRecvError(c *config.C) {
|
||||||
@@ -404,8 +448,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
|
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
|
||||||
Info("Loaded send_recv_error config")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,8 +471,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
|
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
|
||||||
Info("Loaded accept_recv_error config")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -505,15 +547,23 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
|
var errs []error
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
for _, u := range f.writers {
|
// Release the udp readers
|
||||||
|
for i, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while closing udp socket")
|
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
|
||||||
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun device (closing the tun also closes all readers)
|
||||||
return f.inside.Close()
|
closeErr := f.inside.Close()
|
||||||
|
if closeErr != nil {
|
||||||
|
errs = append(errs, closeErr)
|
||||||
|
}
|
||||||
|
f.wg.Done()
|
||||||
|
return errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|||||||
259
lighthouse.go
259
lighthouse.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -15,10 +16,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
@@ -69,18 +70,19 @@ type LightHouse struct {
|
|||||||
// Addr's of relays that can be used by peers to access me
|
// Addr's of relays that can be used by peers to access me
|
||||||
relaysForMe atomic.Pointer[[]netip.Addr]
|
relaysForMe atomic.Pointer[[]netip.Addr]
|
||||||
|
|
||||||
queryChan chan netip.Addr
|
updateTrigger chan struct{}
|
||||||
|
queryChan chan netip.Addr
|
||||||
|
|
||||||
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
|
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
|
||||||
|
|
||||||
metrics *MessageMetrics
|
metrics *MessageMetrics
|
||||||
metricHolepunchTx metrics.Counter
|
metricHolepunchTx metrics.Counter
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
||||||
// addrMap should be nil unless this is during a config reload
|
// addrMap should be nil unless this is during a config reload
|
||||||
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||||
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||||
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
||||||
if amLighthouse && nebulaPort == 0 {
|
if amLighthouse && nebulaPort == 0 {
|
||||||
@@ -105,6 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
nebulaPort: nebulaPort,
|
nebulaPort: nebulaPort,
|
||||||
punchConn: pc,
|
punchConn: pc,
|
||||||
punchy: p,
|
punchy: p,
|
||||||
|
updateTrigger: make(chan struct{}, 1),
|
||||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@@ -131,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
case *util.ContextualError:
|
case *util.ContextualError:
|
||||||
v.Log(l)
|
v.Log(l)
|
||||||
case error:
|
case error:
|
||||||
l.WithError(err).Error("failed to reload lighthouse")
|
l.Error("failed to reload lighthouse", "error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -203,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
||||||
addr := addrs[0].Unmap()
|
addr := addrs[0].Unmap()
|
||||||
if lh.myVpnNetworksTable.Contains(addr) {
|
if lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
|
||||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
"addr", rawAddr,
|
||||||
|
"entry", i+1,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
|
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
|
||||||
|
|
||||||
if !initial {
|
if !initial {
|
||||||
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
|
lh.l.Info("lighthouse.interval changed",
|
||||||
|
"interval", lh.interval.Load(),
|
||||||
|
)
|
||||||
|
|
||||||
if lh.updateCancel != nil {
|
if lh.updateCancel != nil {
|
||||||
// May not always have a running routine
|
// May not always have a running routine
|
||||||
@@ -316,6 +323,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
if !initial {
|
if !initial {
|
||||||
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
||||||
lh.l.Info("lighthouse.hosts has changed")
|
lh.l.Info("lighthouse.hosts has changed")
|
||||||
|
lh.TriggerUpdate()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
||||||
configRIP, err := netip.ParseAddr(v)
|
configRIP, err := netip.ParseAddr(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
|
lh.l.Warn("Parse relay from config failed",
|
||||||
|
"relay", v,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
lh.l.WithField("relay", v).Info("Read relay from config")
|
lh.l.Info("Read relay from config", "relay", v)
|
||||||
relaysForMe = append(relaysForMe, configRIP)
|
relaysForMe = append(relaysForMe, configRIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -360,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
|
||||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
"vpnAddr", addr,
|
||||||
|
"networks", lh.myVpnNetworks,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
out[i] = addr
|
out[i] = addr
|
||||||
}
|
}
|
||||||
@@ -432,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
|
||||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
"vpnAddr", vpnAddr,
|
||||||
|
"networks", lh.myVpnNetworks,
|
||||||
|
"entry", i+1,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
vals, ok := v.([]any)
|
vals, ok := v.([]any)
|
||||||
@@ -534,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
|||||||
lh.Lock()
|
lh.Lock()
|
||||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||||
if ok {
|
if ok {
|
||||||
|
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
|
||||||
for _, addr := range allVpnAddrs {
|
for _, addr := range allVpnAddrs {
|
||||||
srm := lh.addrMap[addr]
|
srm := lh.addrMap[addr]
|
||||||
if srm == rm {
|
if srm == rm {
|
||||||
delete(lh.addrMap, addr)
|
delete(lh.addrMap, addr)
|
||||||
if lh.l.Level >= logrus.DebugLevel {
|
if debugEnabled {
|
||||||
lh.l.Debugf("deleting %s from lighthouse.", addr)
|
lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -656,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
|||||||
|
|
||||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddrs", vpnAddrs,
|
||||||
|
"udpAddr", to,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
return false
|
return false
|
||||||
@@ -675,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
|||||||
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
|
||||||
udpAddr := protoV4AddrPortToNetAddrPort(to)
|
udpAddr := protoV4AddrPortToNetAddrPort(to)
|
||||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddr", vpnAddr,
|
||||||
|
"udpAddr", udpAddr,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -695,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
|
|||||||
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
|
||||||
udpAddr := protoV6AddrPortToNetAddrPort(to)
|
udpAddr := protoV6AddrPortToNetAddrPort(to)
|
||||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddr", vpnAddr,
|
||||||
|
"udpAddr", udpAddr,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -713,21 +739,14 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
|||||||
|
|
||||||
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
||||||
l := lh.GetLighthouses()
|
l := lh.GetLighthouses()
|
||||||
for i := range l {
|
return slices.Contains(l, vpnAddr)
|
||||||
if l[i] == vpnAddr {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
||||||
l := lh.GetLighthouses()
|
l := lh.GetLighthouses()
|
||||||
for i := range vpnAddrs {
|
for i := range vpnAddrs {
|
||||||
for j := range l {
|
if slices.Contains(l, vpnAddrs[i]) {
|
||||||
if l[j] == vpnAddrs[i] {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -779,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if !addr.Is4() {
|
if !addr.Is4() {
|
||||||
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
|
||||||
Error("Can't query lighthouse for v6 address using a v1 protocol")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -791,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
v1Query, err = msg.Marshal()
|
v1Query, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
lh.l.Error("Failed to marshal lighthouse v1 query payload",
|
||||||
WithField("lighthouseAddr", lhVpnAddr).
|
"error", err,
|
||||||
Error("Failed to marshal lighthouse v1 query payload")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -808,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
v2Query, err = msg.Marshal()
|
v2Query, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
lh.l.Error("Failed to marshal lighthouse v2 query payload",
|
||||||
WithField("lighthouseAddr", lhVpnAddr).
|
"error", err,
|
||||||
Error("Failed to marshal lighthouse v2 query payload")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -819,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
queried++
|
queried++
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
|
lh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "query",
|
||||||
|
"queryVpnAddr", addr,
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -848,11 +877,24 @@ func (lh *LightHouse) StartUpdateWorker() {
|
|||||||
return
|
return
|
||||||
case <-clockSource.C:
|
case <-clockSource.C:
|
||||||
continue
|
continue
|
||||||
|
case <-lh.updateTrigger:
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking
|
||||||
|
// operation intended to be called after a handshake completes with a lighthouse,
|
||||||
|
// so the lighthouse has our current addresses without waiting for the next
|
||||||
|
// periodic update.
|
||||||
|
func (lh *LightHouse) TriggerUpdate() {
|
||||||
|
select {
|
||||||
|
case lh.updateTrigger <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) SendUpdate() {
|
func (lh *LightHouse) SendUpdate() {
|
||||||
var v4 []*V4AddrPort
|
var v4 []*V4AddrPort
|
||||||
var v6 []*V6AddrPort
|
var v6 []*V6AddrPort
|
||||||
@@ -898,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if v1Update == nil {
|
if v1Update == nil {
|
||||||
if !lh.myVpnNetworks[0].Addr().Is4() {
|
if !lh.myVpnNetworks[0].Addr().Is4() {
|
||||||
lh.l.WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
|
||||||
Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var relays []uint32
|
var relays []uint32
|
||||||
@@ -923,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
v1Update, err = msg.Marshal()
|
v1Update, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Error while marshaling for lighthouse v1 update",
|
||||||
Error("Error while marshaling for lighthouse v1 update")
|
"error", err,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -950,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
v2Update, err = msg.Marshal()
|
v2Update, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Error while marshaling for lighthouse v2 update",
|
||||||
Error("Error while marshaling for lighthouse v2 update")
|
"error", err,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -960,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
updated++
|
updated++
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v)
|
lh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "update",
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -974,7 +1024,7 @@ type LightHouseHandler struct {
|
|||||||
out []byte
|
out []byte
|
||||||
pb []byte
|
pb []byte
|
||||||
meta *NebulaMeta
|
meta *NebulaMeta
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
||||||
@@ -1023,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := n.Unmarshal(p)
|
err := n.Unmarshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.Error("Failed to unmarshal lighthouse packet",
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"udpAddr", rAddr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.Error("Invalid lighthouse update",
|
||||||
Error("Invalid lighthouse update")
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"udpAddr", rAddr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1058,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
|
lhh.l.Debug("I don't answer queries, but received one", "from", addr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
lhh.l.Debug("Dropping malformed HostQuery",
|
||||||
Debugln("Dropping malformed HostQuery")
|
"from", fromVpnAddrs,
|
||||||
|
"details", n.Details,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
|
||||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"queryVpnAddr", queryVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1101,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
lhh.l.Error("Failed to marshal lighthouse host query reply",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1129,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newDest
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
lhh.l.Debug("unable to punch to host, no addresses in common",
|
||||||
|
"to", crt.Networks(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1156,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
|
lhh.l.Error("Failed to marshal lighthouse host was queried for",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1198,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
|||||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
lhh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "coalesceAnswers",
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1212,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
|
|
||||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
lhh.l.Error("dropping malformed HostQueryReply",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1238,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1262,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
|
|
||||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
lhh.l.Debug("Host sent invalid update",
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"answer", detailsVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1285,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
switch useVersion {
|
switch useVersion {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
if !fromVpnAddrs[0].Is4() {
|
if !fromVpnAddrs[0].Is4() {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrB := fromVpnAddrs[0].As4()
|
vpnAddrB := fromVpnAddrs[0].As4()
|
||||||
@@ -1293,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
case cert.Version2:
|
case cert.Version2:
|
||||||
// do nothing, we want to send a blank message
|
// do nothing, we want to send a blank message
|
||||||
default:
|
default:
|
||||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
lhh.l.Error("invalid protocol version", "useVersion", useVersion)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ln, err := n.MarshalTo(lhh.pb)
|
ln, err := n.MarshalTo(lhh.pb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
|
lhh.l.Error("Failed to marshal lighthouse host update ack",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1316,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
|
|
||||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
lhh.l.Debug("dropping invalid HostPunchNotification",
|
||||||
|
"details", n.Details,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1334,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
lhh.l.Debug("Punching",
|
||||||
|
"vpnPeer", vpnPeer,
|
||||||
|
"logVpnAddr", logVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1360,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
if lhh.lh.punchy.GetRespond() {
|
if lhh.lh.punchy.GetRespond() {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
lhh.l.Debug("Sending a nebula test packet",
|
||||||
|
"vpnAddr", detailsVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -42,14 +41,14 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
|
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
|
||||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
||||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
_, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
lh2 := "10.128.0.3"
|
lh2 := "10.128.0.3"
|
||||||
c = config.NewC(l)
|
c = config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
|
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
|
||||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
|
||||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
_, err = NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||||
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
@@ -99,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(b.Context(), l, c, cs, nil, nil)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
||||||
@@ -202,7 +201,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
@@ -288,7 +287,7 @@ func TestLighthouse_reload(t *testing.T) {
|
|||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nc := map[string]any{
|
nc := map[string]any{
|
||||||
@@ -523,7 +522,7 @@ func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
|
|||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
@@ -589,7 +588,7 @@ func TestLighthouse_DeletesWork(t *testing.T) {
|
|||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
|
|||||||
45
logger.go
45
logger.go
@@ -1,45 +0,0 @@
|
|||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func configLogger(l *logrus.Logger, c *config.C) error {
|
|
||||||
// set up our logging level
|
|
||||||
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
|
||||||
}
|
|
||||||
l.SetLevel(logLevel)
|
|
||||||
|
|
||||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
|
||||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
|
||||||
fullTimestamp := (timestampFormat != "")
|
|
||||||
if timestampFormat == "" {
|
|
||||||
timestampFormat = time.RFC3339
|
|
||||||
}
|
|
||||||
|
|
||||||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
|
||||||
switch logFormat {
|
|
||||||
case "text":
|
|
||||||
l.Formatter = &logrus.TextFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
FullTimestamp: fullTimestamp,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
case "json":
|
|
||||||
l.Formatter = &logrus.JSONFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
233
logging/logger.go
Normal file
233
logging/logger.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
// Package logging wires the nebula runtime-reconfigurable slog handler used
|
||||||
|
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
|
||||||
|
// NewLogger, then call ApplyConfig at startup and from a config reload
|
||||||
|
// callback to push logging.level, logging.format, and
|
||||||
|
// logging.disable_timestamp changes onto the logger without rebuilding it.
|
||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
|
||||||
|
// here keeps the logging package from depending on config directly, which
|
||||||
|
// would cycle through the shared test helpers (test.NewLogger imports
|
||||||
|
// logging, and config's tests import test). *config.C satisfies this
|
||||||
|
// interface structurally with no adapter.
|
||||||
|
type Config interface {
|
||||||
|
GetString(key, def string) string
|
||||||
|
GetBool(key string, def bool) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LevelTrace is a custom slog level below Debug, used when logging.level is
|
||||||
|
// "trace". slog has no builtin trace level; the value is one step below
|
||||||
|
// slog.LevelDebug in slog's 4-point spacing.
|
||||||
|
const LevelTrace = slog.Level(-8)
|
||||||
|
|
||||||
|
// NewLogger returns a *slog.Logger whose level, format, and timestamp
|
||||||
|
// emission can be reconfigured at runtime via ApplyConfig and the SSH debug
|
||||||
|
// commands. The default configuration is info-level text output so log
|
||||||
|
// calls made before ApplyConfig runs still produce output. Timestamps
|
||||||
|
// follow slog's default RFC3339Nano format; set logging.disable_timestamp
|
||||||
|
// in config to suppress them.
|
||||||
|
//
|
||||||
|
// ApplyConfig and the SSH commands discover the reconfig surface via
|
||||||
|
// structural type-assertion on l.Handler(), so replacement implementations
|
||||||
|
// (tests, platform-specific sinks) need only implement the subset of
|
||||||
|
// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)}
|
||||||
|
// they care about. Callers that pass a plain *slog.Logger without these
|
||||||
|
// methods get a silent no-op; reconfiguration is always opt-in.
|
||||||
|
func NewLogger(w io.Writer) *slog.Logger {
|
||||||
|
return slog.New(NewHandler(w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler builds the *Handler that NewLogger wraps. Exported for
|
||||||
|
// platform-specific sinks (notably cmd/nebula-service/logs_windows.go)
|
||||||
|
// that want to wrap the handler with extra behavior, such as tagging each
|
||||||
|
// record with its Event Log severity, while still benefiting from all the
|
||||||
|
// level / format / timestamp / WithAttrs machinery implemented here.
|
||||||
|
func NewHandler(w io.Writer) *Handler {
|
||||||
|
root := &handlerRoot{}
|
||||||
|
root.level.Set(slog.LevelInfo)
|
||||||
|
opts := &slog.HandlerOptions{Level: &root.level}
|
||||||
|
return &Handler{
|
||||||
|
root: root,
|
||||||
|
text: slog.NewTextHandler(w, opts),
|
||||||
|
json: slog.NewJSONHandler(w, opts),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlerRoot carries the reconfiguration state shared by every logger
|
||||||
|
// derived from a NewHandler call. All fields are consulted on the log
|
||||||
|
// path and updated lock-free.
|
||||||
|
type handlerRoot struct {
|
||||||
|
level slog.LevelVar
|
||||||
|
disableTimestamp atomic.Bool
|
||||||
|
// jsonMode picks which of the pre-derived inner handlers Handler.Handle
|
||||||
|
// dispatches to. Flipping it propagates instantly to every derived logger
|
||||||
|
// without rebuilding or chain-replaying anything.
|
||||||
|
jsonMode atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler is the slog.Handler returned by NewHandler. It holds two
|
||||||
|
// pre-derived slog handlers -- one text, one json -- both built from the
|
||||||
|
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
|
||||||
|
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
|
||||||
|
// effect immediately across the whole process without having to rebuild
|
||||||
|
// any derived loggers.
|
||||||
|
type Handler struct {
|
||||||
|
root *handlerRoot
|
||||||
|
text slog.Handler
|
||||||
|
json slog.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
|
||||||
|
return h.root.level.Level() <= l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
|
||||||
|
if h.root.disableTimestamp.Load() {
|
||||||
|
r.Time = time.Time{}
|
||||||
|
}
|
||||||
|
if h.root.jsonMode.Load() {
|
||||||
|
return h.json.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
return h.text.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
root: h.root,
|
||||||
|
text: h.text.WithAttrs(attrs),
|
||||||
|
json: h.json.WithAttrs(attrs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) WithGroup(name string) slog.Handler {
|
||||||
|
if name == "" {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
root: h.root,
|
||||||
|
text: h.text.WithGroup(name),
|
||||||
|
json: h.json.WithGroup(name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLevel updates the effective log level. Propagates to every derived
|
||||||
|
// logger via the shared LevelVar.
|
||||||
|
func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) }
|
||||||
|
|
||||||
|
// GetLevel reports the current log level.
|
||||||
|
func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() }
|
||||||
|
|
||||||
|
// SetFormat flips the output format atomically. Valid formats are "text"
|
||||||
|
// and "json". Every derived logger sees the new format on its next Handle
|
||||||
|
// call; no rebuild or registration is required.
|
||||||
|
func (h *Handler) SetFormat(format string) error {
|
||||||
|
switch format {
|
||||||
|
case "text":
|
||||||
|
h.root.jsonMode.Store(false)
|
||||||
|
case "json":
|
||||||
|
h.root.jsonMode.Store(true)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFormat reports the currently selected format name.
|
||||||
|
func (h *Handler) GetFormat() string {
|
||||||
|
if h.root.jsonMode.Load() {
|
||||||
|
return "json"
|
||||||
|
}
|
||||||
|
return "text"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableTimestamp toggles whether Handle zeroes r.Time before
|
||||||
|
// dispatching (slog's builtin text/json handlers skip emitting the time
|
||||||
|
// attribute on a zero time).
|
||||||
|
func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) }
|
||||||
|
|
||||||
|
// ApplyConfig reads logging.level, logging.format, and (optionally)
|
||||||
|
// logging.disable_timestamp from c and applies them to l. The reconfig
|
||||||
|
// surface is discovered via structural type-assertion on l.Handler(), so
|
||||||
|
// foreign handlers silently opt out of whichever capabilities they do not
|
||||||
|
// implement.
|
||||||
|
//
|
||||||
|
// nebula.Main does NOT call this function on your behalf; callers that want
|
||||||
|
// config-driven log level / format / timestamp updates invoke it at
|
||||||
|
// startup and register it as a reload callback themselves. This keeps the
|
||||||
|
// library from mutating an embedder's logger without their say-so.
|
||||||
|
func ApplyConfig(l *slog.Logger, c Config) error {
|
||||||
|
h := l.Handler()
|
||||||
|
|
||||||
|
lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok {
|
||||||
|
ls.SetLevel(lvl)
|
||||||
|
}
|
||||||
|
|
||||||
|
format := strings.ToLower(c.GetString("logging.format", "text"))
|
||||||
|
if fs, ok := h.(interface{ SetFormat(string) error }); ok {
|
||||||
|
if err := fs.SetFormat(format); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok {
|
||||||
|
ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseLevel converts a config-string level name ("trace", "debug", "info",
|
||||||
|
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and
|
||||||
|
// "panic" are accepted for backwards compatibility with pre-slog configs
|
||||||
|
// and both map to slog.LevelError.
|
||||||
|
func ParseLevel(s string) (slog.Level, error) {
|
||||||
|
switch s {
|
||||||
|
case "trace":
|
||||||
|
return LevelTrace, nil
|
||||||
|
case "debug":
|
||||||
|
return slog.LevelDebug, nil
|
||||||
|
case "info":
|
||||||
|
return slog.LevelInfo, nil
|
||||||
|
case "warn", "warning":
|
||||||
|
return slog.LevelWarn, nil
|
||||||
|
case "error":
|
||||||
|
return slog.LevelError, nil
|
||||||
|
case "fatal", "panic":
|
||||||
|
return slog.LevelError, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("not a valid logging level: %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LevelName returns a human-readable name for a slog.Level matching the
|
||||||
|
// strings accepted by ParseLevel.
|
||||||
|
func LevelName(l slog.Level) string {
|
||||||
|
switch {
|
||||||
|
case l <= LevelTrace:
|
||||||
|
return "trace"
|
||||||
|
case l <= slog.LevelDebug:
|
||||||
|
return "debug"
|
||||||
|
case l <= slog.LevelInfo:
|
||||||
|
return "info"
|
||||||
|
case l <= slog.LevelWarn:
|
||||||
|
return "warn"
|
||||||
|
default:
|
||||||
|
return "error"
|
||||||
|
}
|
||||||
|
}
|
||||||
90
logging/logger_bench_test.go
Normal file
90
logging/logger_bench_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkLogger_* compare the handler returned by NewLogger against a
|
||||||
|
// stock slog text handler. The key thing we care about is the per-log
|
||||||
|
// cost on a logger that has been derived via .With(), because that is the
|
||||||
|
// shape subsystems store on their structs (HostInfo.logger(),
|
||||||
|
// lh.l.With("subsystem", ...), etc.) and call from hot paths.
|
||||||
|
|
||||||
|
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gated-off-path benchmarks: mimic the typical hot-path shape
|
||||||
|
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below
|
||||||
|
// the active level. This is the dominant pattern in inside.go/outside.go and
|
||||||
|
// what we pay on every packet.
|
||||||
|
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
ctx := context.Background()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if l.Enabled(ctx, slog.LevelDebug) {
|
||||||
|
l.Debug("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
ctx := context.Background()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if l.Enabled(ctx, slog.LevelDebug) {
|
||||||
|
l.Debug("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
93
main.go
93
main.go
@@ -3,13 +3,13 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|
||||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
buildVersion = moduleVersion()
|
buildVersion = moduleVersion()
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logger
|
|
||||||
l.Formatter = &logrus.TextFormatter{
|
|
||||||
FullTimestamp: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print the config if in test, the exit comes later
|
// Print the config if in test, the exit comes later
|
||||||
if configTest {
|
if configTest {
|
||||||
b, err := yaml.Marshal(c.Settings)
|
b, err := yaml.Marshal(c.Settings)
|
||||||
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Print the final config
|
// Print the final config
|
||||||
l.Println(string(b))
|
l.Info(string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := configLogger(l, c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
err := configLogger(l, c)
|
|
||||||
if err != nil {
|
|
||||||
l.WithError(err).Error("Failed to configure the logger")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
pki, err := NewPKIFromConfig(l, c)
|
pki, err := NewPKIFromConfig(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||||
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||||
}
|
}
|
||||||
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||||
}
|
}
|
||||||
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
|
||||||
sshStart = nil
|
sshStart = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -99,19 +82,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
routines = 1
|
routines = 1
|
||||||
}
|
}
|
||||||
if routines > 1 {
|
if routines > 1 {
|
||||||
l.WithField("routines", routines).Info("Using multiple routines")
|
l.Info("Using multiple routines", "routines", routines)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// deprecated and undocumented
|
// deprecated and undocumented
|
||||||
tunQueues := c.GetInt("tun.routines", 1)
|
tunQueues := c.GetInt("tun.routines", 1)
|
||||||
udpQueues := c.GetInt("listen.routines", 1)
|
udpQueues := c.GetInt("listen.routines", 1)
|
||||||
if tunQueues > udpQueues {
|
routines = max(tunQueues, udpQueues)
|
||||||
routines = tunQueues
|
|
||||||
} else {
|
|
||||||
routines = udpQueues
|
|
||||||
}
|
|
||||||
if routines != 1 {
|
if routines != 1 {
|
||||||
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
|
l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
conntrackCacheTimeout = 1 * time.Second
|
conntrackCacheTimeout = 1 * time.Second
|
||||||
}
|
}
|
||||||
if conntrackCacheTimeout > 0 {
|
if conntrackCacheTimeout > 0 {
|
||||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tun overlay.Device
|
var tun overlay.Device
|
||||||
@@ -170,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
@@ -205,27 +184,19 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
messageMetrics = newMessageMetricsOnlyRecvError()
|
messageMetrics = newMessageMetricsOnlyRecvError()
|
||||||
}
|
}
|
||||||
|
|
||||||
useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false)
|
|
||||||
|
|
||||||
handshakeConfig := HandshakeConfig{
|
handshakeConfig := HandshakeConfig{
|
||||||
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
||||||
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
|
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
|
||||||
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
||||||
useRelays: useRelays,
|
|
||||||
|
|
||||||
messageMetrics: messageMetrics,
|
messageMetrics: messageMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||||
|
|
||||||
serveDns := false
|
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
|
||||||
if c.GetBool("lighthouse.serve_dns", false) {
|
if err != nil {
|
||||||
if c.GetBool("lighthouse.am_lighthouse", false) {
|
l.Warn("Failed to start DNS responder", "error", err)
|
||||||
serveDns = true
|
|
||||||
} else {
|
|
||||||
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
@@ -234,7 +205,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
Outside: udpConns[0],
|
Outside: udpConns[0],
|
||||||
pki: pki,
|
pki: pki,
|
||||||
Firewall: fw,
|
Firewall: fw,
|
||||||
ServeDns: serveDns,
|
DnsServer: ds,
|
||||||
HandshakeManager: handshakeManager,
|
HandshakeManager: handshakeManager,
|
||||||
connectionManager: connManager,
|
connectionManager: connManager,
|
||||||
lightHouse: lightHouse,
|
lightHouse: lightHouse,
|
||||||
@@ -304,7 +275,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
go handshakeManager.Run(ctx)
|
go handshakeManager.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
statsStart, err := startStats(l, c, buildVersion, configTest)
|
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
|
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
|
||||||
}
|
}
|
||||||
@@ -317,23 +288,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
attachCommands(l, c, ssh, ifce)
|
attachCommands(l, c, ssh, ifce)
|
||||||
|
|
||||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
|
||||||
var dnsStart func()
|
|
||||||
if lightHouse.amLighthouse && serveDns {
|
|
||||||
l.Debugln("Starting dns server")
|
|
||||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
ifce,
|
state: StateReady,
|
||||||
l,
|
f: ifce,
|
||||||
ctx,
|
l: l,
|
||||||
cancel,
|
ctx: ctx,
|
||||||
sshStart,
|
cancel: cancel,
|
||||||
statsStart,
|
sshStart: sshStart,
|
||||||
dnsStart,
|
statsStart: stats.Start,
|
||||||
lightHouse.StartUpdateWorker,
|
dnsStart: ds.Start,
|
||||||
connManager.Start,
|
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||||
|
connectionManagerStart: connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ type MessageMetrics struct {
|
|||||||
|
|
||||||
rxUnknown metrics.Counter
|
rxUnknown metrics.Counter
|
||||||
txUnknown metrics.Counter
|
txUnknown metrics.Counter
|
||||||
|
|
||||||
|
rxInvalid metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
|
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
|
||||||
@@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func (m *MessageMetrics) RxInvalid(i int64) {
|
||||||
|
if m != nil && m.rxInvalid != nil {
|
||||||
|
m.rxInvalid.Inc(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newMessageMetrics() *MessageMetrics {
|
func newMessageMetrics() *MessageMetrics {
|
||||||
gen := func(t string) [][]metrics.Counter {
|
gen := func(t string) [][]metrics.Counter {
|
||||||
@@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics {
|
|||||||
|
|
||||||
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
|
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
|
||||||
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
|
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
|
||||||
|
rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
1072
nebula.pb.go
1072
nebula.pb.go
File diff suppressed because it is too large
Load Diff
26
nebula.proto
26
nebula.proto
@@ -60,29 +60,9 @@ message NebulaPing {
|
|||||||
uint64 Time = 2;
|
uint64 Time = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message NebulaHandshake {
|
// NebulaHandshake / NebulaHandshakeDetails moved to
|
||||||
NebulaHandshakeDetails Details = 1;
|
// handshake/handshake.proto. The handshake package speaks that wire format
|
||||||
bytes Hmac = 2;
|
// directly via a hand-written encoder/decoder.
|
||||||
}
|
|
||||||
|
|
||||||
message MultiPortDetails {
|
|
||||||
bool RxSupported = 1;
|
|
||||||
bool TxSupported = 2;
|
|
||||||
uint32 BasePort = 3;
|
|
||||||
uint32 TotalPorts = 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
message NebulaHandshakeDetails {
|
|
||||||
bytes Cert = 1;
|
|
||||||
uint32 InitiatorIndex = 2;
|
|
||||||
uint32 ResponderIndex = 3;
|
|
||||||
uint64 Cookie = 4;
|
|
||||||
uint64 Time = 5;
|
|
||||||
uint32 CertVersion = 8;
|
|
||||||
|
|
||||||
MultiPortDetails InitiatorMultiPort = 6;
|
|
||||||
MultiPortDetails ResponderMultiPort = 7;
|
|
||||||
}
|
|
||||||
|
|
||||||
message NebulaControl {
|
message NebulaControl {
|
||||||
enum MessageType {
|
enum MessageType {
|
||||||
|
|||||||
14
noise.go
14
noise.go
@@ -15,14 +15,12 @@ type endianness interface {
|
|||||||
var noiseEndianness endianness = binary.BigEndian
|
var noiseEndianness endianness = binary.BigEndian
|
||||||
|
|
||||||
type NebulaCipherState struct {
|
type NebulaCipherState struct {
|
||||||
c noise.Cipher
|
c cipher.AEAD
|
||||||
//k [32]byte
|
|
||||||
//n uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
|
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
|
||||||
return &NebulaCipherState{c: s.Cipher()}
|
x := s.Cipher()
|
||||||
|
return &NebulaCipherState{c: x.(cipher.AEAD)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncryptDanger encrypts and authenticates a given payload.
|
// EncryptDanger encrypts and authenticates a given payload.
|
||||||
@@ -46,7 +44,7 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
|
|||||||
nb[2] = 0
|
nb[2] = 0
|
||||||
nb[3] = 0
|
nb[3] = 0
|
||||||
noiseEndianness.PutUint64(nb[4:], n)
|
noiseEndianness.PutUint64(nb[4:], n)
|
||||||
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
|
out = s.c.Seal(out, nb, plaintext, ad)
|
||||||
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
|
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
|
||||||
return out, nil
|
return out, nil
|
||||||
} else {
|
} else {
|
||||||
@@ -61,7 +59,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
|
|||||||
nb[2] = 0
|
nb[2] = 0
|
||||||
nb[3] = 0
|
nb[3] = 0
|
||||||
noiseEndianness.PutUint64(nb[4:], n)
|
noiseEndianness.PutUint64(nb[4:], n)
|
||||||
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
|
return s.c.Open(out, nb, ciphertext, ad)
|
||||||
} else {
|
} else {
|
||||||
return []byte{}, nil
|
return []byte{}, nil
|
||||||
}
|
}
|
||||||
@@ -69,7 +67,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
|
|||||||
|
|
||||||
func (s *NebulaCipherState) Overhead() int {
|
func (s *NebulaCipherState) Overhead() int {
|
||||||
if s != nil {
|
if s != nil {
|
||||||
return s.c.(cipher.AEAD).Overhead()
|
return s.c.Overhead()
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build !boringcrypto
|
//go:build !boringcrypto
|
||||||
// +build !boringcrypto
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
523
outside.go
523
outside.go
@@ -1,15 +1,16 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@@ -19,208 +20,239 @@ const (
|
|||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrOutOfWindow = errors.New("out of window packet")
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
|
// TODO: record metrics for rx holepunch/punchy packets?
|
||||||
if len(packet) > 1 {
|
if len(packet) > 1 {
|
||||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
|
f.messageMetrics.RxInvalid(1)
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
f.l.Debug("Error while parsing inbound packet",
|
||||||
|
"from", via,
|
||||||
|
"error", err,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.Version != header.Version {
|
||||||
|
f.messageMetrics.RxInvalid(1)
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
f.l.Debug("Unexpected header version received", "from", via)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check before processing to see if this is a expected type/subtype
|
||||||
|
if !h.IsValidSubType() {
|
||||||
|
f.messageMetrics.RxInvalid(1)
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
f.l.Debug("Unexpected packet received", "from", via)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
|
||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
f.messageMetrics.RxInvalid(1)
|
||||||
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
f.l.Debug("Refusing to process double encrypted packet", "from", via)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// don't keep Rx metrics for message type, since you can see those in the tun metrics
|
||||||
|
if h.Type != header.Message {
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unencrypted packets
|
||||||
|
switch h.Type {
|
||||||
|
case header.Handshake:
|
||||||
|
f.handshakeManager.HandleIncoming(via, packet, h)
|
||||||
|
return
|
||||||
|
|
||||||
|
case header.RecvError:
|
||||||
|
f.handleRecvError(via.UdpAddr, h)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Relay packets are special
|
||||||
|
isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay)
|
||||||
|
|
||||||
var hostinfo *HostInfo
|
var hostinfo *HostInfo
|
||||||
// verify if we've seen this index before, otherwise respond to the handshake initiation
|
if isMessageRelay {
|
||||||
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
|
||||||
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
||||||
} else {
|
} else {
|
||||||
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ci *ConnectionState
|
// At this point we should have a valid existing tunnel, verify and send
|
||||||
if hostinfo != nil {
|
// recvError if necessary
|
||||||
ci = hostinfo.ConnectionState
|
if hostinfo == nil || hostinfo.ConnectionState == nil {
|
||||||
|
if !via.IsRelayed {
|
||||||
|
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// All remaining packets are encrypted
|
||||||
|
ci := hostinfo.ConnectionState
|
||||||
|
if !ci.window.Check(f.l, h.MessageCounter) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Relay packets are special
|
||||||
|
if isMessageRelay {
|
||||||
|
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Failed to decrypt packet",
|
||||||
|
"error", err,
|
||||||
|
"from", via,
|
||||||
|
"header", h,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Roam before we respond
|
||||||
|
f.handleHostRoaming(hostinfo, via)
|
||||||
|
f.connectionManager.In(hostinfo)
|
||||||
|
|
||||||
switch h.Type {
|
switch h.Type {
|
||||||
case header.Message:
|
case header.Message:
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h.Subtype {
|
switch h.Subtype {
|
||||||
case header.MessageNone:
|
case header.MessageNone:
|
||||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
|
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache)
|
||||||
return
|
default:
|
||||||
}
|
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
|
||||||
case header.MessageRelay:
|
return
|
||||||
// The entire body is sent as AD, not encrypted.
|
|
||||||
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
|
||||||
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
|
||||||
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
|
|
||||||
// which will gracefully fail in the DecryptDanger call.
|
|
||||||
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
|
|
||||||
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Successfully validated the thing. Get rid of the Relay header.
|
|
||||||
signedPayload = signedPayload[header.Len:]
|
|
||||||
// Pull the Roaming parts up here, and return in all call paths.
|
|
||||||
f.handleHostRoaming(hostinfo, via)
|
|
||||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
|
||||||
|
|
||||||
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
|
||||||
if !ok {
|
|
||||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
|
||||||
// its internal mapping. This should never happen.
|
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch relay.Type {
|
|
||||||
case TerminalType:
|
|
||||||
// If I am the target of this relay, process the unwrapped packet
|
|
||||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
|
||||||
via = ViaSender{
|
|
||||||
UdpAddr: via.UdpAddr,
|
|
||||||
relayHI: hostinfo,
|
|
||||||
remoteIdx: relay.RemoteIndex,
|
|
||||||
relay: relay,
|
|
||||||
IsRelayed: true,
|
|
||||||
}
|
|
||||||
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
|
||||||
return
|
|
||||||
case ForwardingType:
|
|
||||||
// Find the target HostInfo relay object
|
|
||||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If that relay is Established, forward the payload through it
|
|
||||||
if targetRelay.State == Established {
|
|
||||||
switch targetRelay.Type {
|
|
||||||
case ForwardingType:
|
|
||||||
// Forward this packet through the relay tunnel
|
|
||||||
// Find the target HostInfo
|
|
||||||
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
|
|
||||||
return
|
|
||||||
case TerminalType:
|
|
||||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case header.LightHouse:
|
case header.LightHouse:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt lighthouse packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
//TODO: assert via is not relayed
|
//TODO: assert via is not relayed
|
||||||
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
|
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
case header.Test:
|
case header.Test:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
switch h.Subtype {
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
case header.TestReply:
|
||||||
|
// No-op, useful for the Roaming and connectionManager side-effects above
|
||||||
|
case header.TestRequest:
|
||||||
|
f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out)
|
||||||
|
default:
|
||||||
|
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt test packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.Subtype == header.TestRequest {
|
|
||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
|
||||||
// to the new IP address before responding
|
|
||||||
f.handleHostRoaming(hostinfo, via)
|
|
||||||
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
|
||||||
// are unauthenticated
|
|
||||||
|
|
||||||
case header.Handshake:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
f.handshakeManager.HandleIncoming(via, packet, h)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.RecvError:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
f.handleRecvError(via.UdpAddr, h)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.CloseTunnel:
|
case header.CloseTunnel:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("from", via).
|
|
||||||
Info("Close tunnel received, tearing down.")
|
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
f.closeTunnel(hostinfo)
|
||||||
return
|
|
||||||
|
|
||||||
case header.Control:
|
case header.Control:
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
f.relayManager.HandleControlMsg(hostinfo, out, f)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt Control packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h)
|
||||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
|
// The entire body is sent as AD, not encrypted.
|
||||||
|
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
||||||
|
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
||||||
|
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
|
||||||
|
// which will gracefully fail in the DecryptDanger call.
|
||||||
|
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
|
||||||
|
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
|
||||||
|
var err error
|
||||||
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Successfully validated the thing. Get rid of the Relay header.
|
||||||
|
signedPayload = signedPayload[header.Len:]
|
||||||
|
// Pull the Roaming parts up here, and return in all call paths.
|
||||||
|
f.handleHostRoaming(hostinfo, via)
|
||||||
|
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||||
|
f.connectionManager.In(hostinfo)
|
||||||
|
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||||
|
|
||||||
|
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
||||||
|
if !ok {
|
||||||
|
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||||
|
// its internal mapping. This should never happen.
|
||||||
|
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.handleHostRoaming(hostinfo, via)
|
switch relay.Type {
|
||||||
|
case TerminalType:
|
||||||
|
// If I am the target of this relay, process the unwrapped packet
|
||||||
|
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||||
|
via = ViaSender{
|
||||||
|
UdpAddr: via.UdpAddr,
|
||||||
|
relayHI: hostinfo,
|
||||||
|
remoteIdx: relay.RemoteIndex,
|
||||||
|
relay: relay,
|
||||||
|
IsRelayed: true,
|
||||||
|
}
|
||||||
|
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
|
case ForwardingType:
|
||||||
|
// Find the target HostInfo relay object
|
||||||
|
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
|
||||||
|
"relayTo", relay.PeerAddr,
|
||||||
|
"error", err,
|
||||||
|
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
// If that relay is Established, forward the payload through it
|
||||||
|
if targetRelay.State == Established {
|
||||||
|
switch targetRelay.Type {
|
||||||
|
case ForwardingType:
|
||||||
|
// Forward this packet through the relay tunnel
|
||||||
|
// Find the target HostInfo
|
||||||
|
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
|
||||||
|
case TerminalType:
|
||||||
|
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
hostinfo.logger(f.l).Info("Unexpected target relay state",
|
||||||
|
"relayTo", relay.PeerAddr,
|
||||||
|
"relayFrom", hostinfo.vpnAddrs[0],
|
||||||
|
"targetRelayState", targetRelay.State,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
||||||
@@ -249,20 +281,27 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
|
|||||||
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
|
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
|
||||||
}
|
}
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||||
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
return
|
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
|
||||||
}
|
|
||||||
|
|
||||||
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
|
||||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||||
Info("Host roamed to new udp ip/port.")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
|
||||||
|
"suppressSeconds", RoamingSuppressSeconds,
|
||||||
|
"udpAddr", hostinfo.remote,
|
||||||
|
"newAddr", via.UdpAddr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
|
||||||
|
"udpAddr", hostinfo.remote,
|
||||||
|
"newAddr", via.UdpAddr,
|
||||||
|
)
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
hostinfo.lastRoamRemote = hostinfo.remote
|
hostinfo.lastRoamRemote = hostinfo.remote
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
hostinfo.SetRemote(via.UdpAddr)
|
||||||
@@ -270,23 +309,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEncrypted returns true if a packet should be processed, false otherwise
|
|
||||||
func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
|
|
||||||
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
|
||||||
if ci == nil {
|
|
||||||
if !via.IsRelayed {
|
|
||||||
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// If the window check fails, refuse to process the packet, but don't send a recv error
|
|
||||||
if !ci.window.Check(f.l, h.MessageCounter) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrPacketTooShort = errors.New("packet is too short")
|
ErrPacketTooShort = errors.New("packet is too short")
|
||||||
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
|
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
|
||||||
@@ -336,13 +358,29 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
proto := layers.IPProtocol(data[protoAt])
|
proto := layers.IPProtocol(data[protoAt])
|
||||||
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
case layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
fp.RemotePort = 0
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = 0
|
fp.LocalPort = 0
|
||||||
fp.Fragment = false
|
fp.Fragment = false
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
case layers.IPProtocolICMPv6:
|
||||||
|
if dataLen < offset+6 {
|
||||||
|
return ErrIPv6PacketTooShort
|
||||||
|
}
|
||||||
|
fp.Protocol = uint8(proto)
|
||||||
|
fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6
|
||||||
|
icmptype := data[offset+1]
|
||||||
|
switch icmptype {
|
||||||
|
case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply:
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier
|
||||||
|
default:
|
||||||
|
fp.RemotePort = 0
|
||||||
|
}
|
||||||
|
fp.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
||||||
if dataLen < offset+4 {
|
if dataLen < offset+4 {
|
||||||
return ErrIPv6PacketTooShort
|
return ErrIPv6PacketTooShort
|
||||||
@@ -432,34 +470,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
|
|
||||||
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
||||||
minLen := ihl
|
minLen := ihl
|
||||||
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
|
if !fp.Fragment {
|
||||||
minLen += minFwPacketLen
|
if fp.Protocol == firewall.ProtoICMP {
|
||||||
|
minLen += minFwPacketLen + 2
|
||||||
|
} else {
|
||||||
|
minLen += minFwPacketLen
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(data) < minLen {
|
if len(data) < minLen {
|
||||||
return ErrIPv4InvalidHeaderLength
|
return ErrIPv4InvalidHeaderLength
|
||||||
}
|
}
|
||||||
|
|
||||||
// Firewall packets are locally oriented
|
if incoming { // Firewall packets are locally oriented
|
||||||
if incoming {
|
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
|
||||||
} else {
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
}
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
if fp.Fragment {
|
||||||
} else {
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
fp.LocalPort = 0
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
|
||||||
}
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
|
||||||
|
fp.LocalPort = 0 //code would be uint16(data[ihl+1])
|
||||||
|
} else if incoming {
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
|
||||||
|
} else {
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -473,34 +515,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
||||||
hostinfo.logger(f.l).WithField("header", h).
|
return nil, ErrOutOfWindow
|
||||||
Debugln("dropping out of window packet")
|
|
||||||
return nil, errors.New("out of window packet")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
var err error
|
err := newPacket(out, true, fwPacket)
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
||||||
return false
|
"error", err,
|
||||||
}
|
"packet", out,
|
||||||
|
)
|
||||||
err = newPacket(out, true, fwPacket)
|
return
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
|
||||||
Warnf("Error while validating inbound packet")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
|
||||||
Debugln("dropping out of window packet")
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
@@ -508,20 +536,19 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).Debug("dropping inbound packet",
|
||||||
WithField("reason", dropReason).
|
"fwPacket", fwPacket,
|
||||||
Debugln("dropping inbound packet")
|
"reason", dropReason,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
_, err = f.readers[q].Write(out)
|
_, err = f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.Error("Failed to write to tun", "error", err)
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
|
func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
|
||||||
@@ -535,35 +562,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
|||||||
|
|
||||||
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||||
_ = f.outside.WriteTo(b, endpoint)
|
_ = f.outside.WriteTo(b, endpoint)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("index", index).
|
f.l.Debug("Recv error sent",
|
||||||
WithField("udpAddr", endpoint).
|
"index", index,
|
||||||
Debug("Recv error sent")
|
"udpAddr", endpoint,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||||
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
|
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
|
||||||
f.l.WithField("index", h.RemoteIndex).
|
f.l.Debug("Recv error received, ignoring",
|
||||||
WithField("udpAddr", addr).
|
"index", h.RemoteIndex,
|
||||||
Debug("Recv error received, ignoring")
|
"udpAddr", addr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("index", h.RemoteIndex).
|
f.l.Debug("Recv error received",
|
||||||
WithField("udpAddr", addr).
|
"index", h.RemoteIndex,
|
||||||
Debug("Recv error received")
|
"udpAddr", addr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
|
f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
f.l.Info("Someone spoofing recv_errors?",
|
||||||
|
"addr", addr,
|
||||||
|
"hostinfoRemote", hostinfo.remote,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
// next layer, missing length byte
|
// next layer, missing length byte
|
||||||
err = newPacket(buffer.Bytes()[:49], true, p)
|
err = newPacket(buffer.Bytes()[:49], true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
err = nil
|
||||||
|
|
||||||
// A good ICMP packet
|
// A good ICMP packet
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
DstIP: net.IPv6linklocalallnodes,
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
}
|
}
|
||||||
|
|
||||||
icmp := layers.ICMPv6{}
|
icmp := layers.ICMPv6{
|
||||||
|
TypeCode: layers.ICMPv6TypeEchoRequest,
|
||||||
buffer.Clear()
|
Checksum: 0x1234,
|
||||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
buffer.Clear()
|
||||||
require.NoError(t, err)
|
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp))
|
||||||
|
require.Error(t, newPacket(buffer.Bytes(), true, p))
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
echo := layers.ICMPv6Echo{
|
||||||
|
Identifier: 0xabcd,
|
||||||
|
SeqNumber: 1234,
|
||||||
|
}
|
||||||
|
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo))
|
||||||
|
require.NoError(t, newPacket(buffer.Bytes(), true, p))
|
||||||
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
assert.Equal(t, uint16(0), p.RemotePort)
|
assert.Equal(t, uint16(0xabcd), p.RemotePort)
|
||||||
assert.Equal(t, uint16(0), p.LocalPort)
|
assert.Equal(t, uint16(0), p.LocalPort)
|
||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
@@ -574,7 +581,7 @@ func BenchmarkParseV6(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
evilBytes := buffer.Bytes()
|
evilBytes := buffer.Bytes()
|
||||||
for i := 0; i < 200; i++ {
|
for range 200 {
|
||||||
evilBytes = append(evilBytes, hopHeader...)
|
evilBytes = append(evilBytes, hopHeader...)
|
||||||
}
|
}
|
||||||
evilBytes = append(evilBytes, lastHopHeader...)
|
evilBytes = append(evilBytes, lastHopHeader...)
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
package test
|
// Package overlaytest provides fakes of overlay.Device for tests that do
|
||||||
|
// not want to touch a real tun device or route table.
|
||||||
|
package overlaytest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@@ -8,6 +10,9 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NoopTun is an overlay.Device that silently discards every read and write.
|
||||||
|
// Useful in tests that need to construct a nebula Interface but do not
|
||||||
|
// exercise the datapath.
|
||||||
type NoopTun struct{}
|
type NoopTun struct{}
|
||||||
|
|
||||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||||
@@ -2,6 +2,7 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -9,7 +10,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -48,11 +48,14 @@ func (r Route) String() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
||||||
routeTree := new(bart.Table[routing.Gateways])
|
routeTree := new(bart.Table[routing.Gateways])
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !allowMTU && r.MTU > 0 {
|
if !allowMTU && r.MTU > 0 {
|
||||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
l.Warn("route MTU is not supported on this platform",
|
||||||
|
"goos", runtime.GOOS,
|
||||||
|
"route", r,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
gateways := r.Via
|
gateways := r.Via
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("1.0.0.2")
|
ip, err := netip.ParseAddr("1.0.0.2")
|
||||||
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
|
|||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 3)
|
assert.Len(t, routes, 3)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("192.168.86.1")
|
ip, err := netip.ParseAddr("192.168.86.1")
|
||||||
|
|||||||
@@ -2,20 +2,29 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMTU = 1300
|
const DefaultMTU = 1300
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
type NameError struct {
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
Name string
|
||||||
|
Underlying error
|
||||||
|
}
|
||||||
|
|
||||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func (e *NameError) Error() string {
|
||||||
|
return fmt.Sprintf("could not set tun device name: %s because %s", e.Name, e.Underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: We may be able to remove routines
|
||||||
|
type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|
||||||
|
func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
switch {
|
switch {
|
||||||
case c.GetBool("tun.disabled", false):
|
case c.GetBool("tun.disabled", false):
|
||||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||||
@@ -27,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -23,10 +23,10 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTun not supported in Android")
|
return nil, fmt.Errorf("newTun not supported in Android")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user