mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 12:57:38 +02:00
Merge remote-tracking branch 'origin/master' into fips140
This commit is contained in:
5
.github/workflows/release.yml
vendored
5
.github/workflows/release.yml
vendored
@@ -209,10 +209,11 @@ jobs:
|
|||||||
id: create_release
|
id: create_release
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GITHUB_REF_NAME: ${{ github.ref_name }}
|
||||||
run: |
|
run: |
|
||||||
cd artifacts
|
cd artifacts
|
||||||
gh release create \
|
gh release create \
|
||||||
--verify-tag \
|
--verify-tag \
|
||||||
--title "Release ${{ github.ref_name }}" \
|
--title "Release ${GITHUB_REF_NAME}" \
|
||||||
"${{ github.ref_name }}" \
|
"${GITHUB_REF_NAME}" \
|
||||||
SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz
|
SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz
|
||||||
|
|||||||
26
.github/workflows/smoke-extra.yml
vendored
26
.github/workflows/smoke-extra.yml
vendored
@@ -18,6 +18,8 @@ jobs:
|
|||||||
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
||||||
name: Run extra smoke tests
|
name: Run extra smoke tests
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
VAGRANT_DEFAULT_PROVIDER: libvirt
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v6
|
||||||
@@ -30,8 +32,13 @@ jobs:
|
|||||||
- name: add hashicorp source
|
- name: add hashicorp source
|
||||||
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
||||||
|
|
||||||
- name: install vagrant
|
- name: install vagrant and libvirt
|
||||||
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
run: |
|
||||||
|
sudo apt-get update && sudo apt-get install -y vagrant libvirt-daemon-system libvirt-dev
|
||||||
|
sudo chmod 666 /dev/kvm
|
||||||
|
sudo usermod -aG libvirt $(whoami)
|
||||||
|
sudo chmod 666 /var/run/libvirt/libvirt-sock
|
||||||
|
vagrant plugin install vagrant-libvirt
|
||||||
|
|
||||||
- name: freebsd-amd64
|
- name: freebsd-amd64
|
||||||
run: make smoke-vagrant/freebsd-amd64
|
run: make smoke-vagrant/freebsd-amd64
|
||||||
@@ -42,10 +49,19 @@ jobs:
|
|||||||
- name: netbsd-amd64
|
- name: netbsd-amd64
|
||||||
run: make smoke-vagrant/netbsd-amd64
|
run: make smoke-vagrant/netbsd-amd64
|
||||||
|
|
||||||
- name: linux-386
|
|
||||||
run: make smoke-vagrant/linux-386
|
|
||||||
|
|
||||||
- name: linux-amd64-ipv6disable
|
- name: linux-amd64-ipv6disable
|
||||||
run: make smoke-vagrant/linux-amd64-ipv6disable
|
run: make smoke-vagrant/linux-amd64-ipv6disable
|
||||||
|
|
||||||
|
# linux-386 runs last because it requires disabling KVM to use VirtualBox,
|
||||||
|
# which prevents libvirt (used by the other tests) from working after this point.
|
||||||
|
- name: install virtualbox for i386 test
|
||||||
|
run: |
|
||||||
|
sudo apt-get install -y virtualbox
|
||||||
|
sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true
|
||||||
|
|
||||||
|
- name: linux-386
|
||||||
|
env:
|
||||||
|
VAGRANT_DEFAULT_PROVIDER: virtualbox
|
||||||
|
run: make smoke-vagrant/linux-386
|
||||||
|
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
|
|||||||
8
.github/workflows/smoke/build-relay.sh
vendored
8
.github/workflows/smoke/build-relay.sh
vendored
@@ -16,8 +16,10 @@ relay:
|
|||||||
am_relay: true
|
am_relay: true
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
export LIGHTHOUSES="192.168.100.1 172.17.0.2:4242"
|
# TEST-NET-3 placeholder IPs; smoke-relay.sh seds them to real container IPs.
|
||||||
export REMOTE_ALLOW_LIST='{"172.17.0.4/32": false, "172.17.0.5/32": false}'
|
# Mapping: .2 lighthouse1, .3 host2, .4 host3, .5 host4.
|
||||||
|
export LIGHTHOUSES="192.168.100.1 203.0.113.2:4242"
|
||||||
|
export REMOTE_ALLOW_LIST='{"203.0.113.4/32": false, "203.0.113.5/32": false}'
|
||||||
|
|
||||||
HOST="host2" ../genconfig.sh >host2.yml <<EOF
|
HOST="host2" ../genconfig.sh >host2.yml <<EOF
|
||||||
relay:
|
relay:
|
||||||
@@ -25,7 +27,7 @@ relay:
|
|||||||
- 192.168.100.1
|
- 192.168.100.1
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
export REMOTE_ALLOW_LIST='{"172.17.0.3/32": false}'
|
export REMOTE_ALLOW_LIST='{"203.0.113.3/32": false}'
|
||||||
|
|
||||||
HOST="host3" ../genconfig.sh >host3.yml
|
HOST="host3" ../genconfig.sh >host3.yml
|
||||||
|
|
||||||
|
|||||||
18
.github/workflows/smoke/build.sh
vendored
18
.github/workflows/smoke/build.sh
vendored
@@ -5,9 +5,15 @@ set -e -x
|
|||||||
rm -rf ./build
|
rm -rf ./build
|
||||||
mkdir ./build
|
mkdir ./build
|
||||||
|
|
||||||
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
|
# Smoke containers run on a dedicated docker network whose subnet is allocated
|
||||||
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
|
# at smoke time, not known at build time. Configs are written with TEST-NET-3
|
||||||
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
|
# placeholder IPs (RFC 5737) and smoke.sh / smoke-vagrant.sh / smoke-relay.sh
|
||||||
|
# sed the real container IPs in before starting nebula.
|
||||||
|
#
|
||||||
|
# Placeholder mapping (last octet == fixed container slot):
|
||||||
|
# 203.0.113.2 -> lighthouse1, 203.0.113.3 -> host2,
|
||||||
|
# 203.0.113.4 -> host3, 203.0.113.5 -> host4.
|
||||||
|
LIGHTHOUSE_IP="203.0.113.2"
|
||||||
|
|
||||||
(
|
(
|
||||||
cd build
|
cd build
|
||||||
@@ -25,16 +31,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{
|
|||||||
../genconfig.sh >lighthouse1.yml
|
../genconfig.sh >lighthouse1.yml
|
||||||
|
|
||||||
HOST="host2" \
|
HOST="host2" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
|
||||||
../genconfig.sh >host2.yml
|
../genconfig.sh >host2.yml
|
||||||
|
|
||||||
HOST="host3" \
|
HOST="host3" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
|
||||||
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host3.yml
|
../genconfig.sh >host3.yml
|
||||||
|
|
||||||
HOST="host4" \
|
HOST="host4" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $LIGHTHOUSE_IP:4242" \
|
||||||
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host4.yml
|
../genconfig.sh >host4.yml
|
||||||
|
|
||||||
|
|||||||
57
.github/workflows/smoke/smoke-relay.sh
vendored
57
.github/workflows/smoke/smoke-relay.sh
vendored
@@ -6,6 +6,8 @@ set -o pipefail
|
|||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
NETWORK="nebula-smoke-relay"
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
echo
|
echo
|
||||||
echo " *** cleanup"
|
echo " *** cleanup"
|
||||||
@@ -16,22 +18,53 @@ cleanup() {
|
|||||||
then
|
then
|
||||||
docker kill lighthouse1 host2 host3 host4
|
docker kill lighthouse1 host2 host3 host4
|
||||||
fi
|
fi
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1
|
||||||
}
|
}
|
||||||
|
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
|
# Create a dedicated smoke network with an explicit subnet (required for --ip
|
||||||
docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
|
# below). Probe a short list of candidates so a locally-used range doesn't
|
||||||
docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
|
# fail the whole test — we only need one to be free.
|
||||||
docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
|
docker network rm "$NETWORK" >/dev/null 2>&1 || true
|
||||||
|
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
|
||||||
|
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
|
||||||
|
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build-relay.sh.
|
||||||
|
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
|
||||||
|
PREFIX="${SUBNET%/*}"
|
||||||
|
PREFIX="${PREFIX%.*}"
|
||||||
|
LIGHTHOUSE_IP="$PREFIX.2"
|
||||||
|
HOST2_IP="$PREFIX.3"
|
||||||
|
HOST3_IP="$PREFIX.4"
|
||||||
|
HOST4_IP="$PREFIX.5"
|
||||||
|
|
||||||
|
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
|
||||||
|
for f in build/host2.yml build/host3.yml build/host4.yml; do
|
||||||
|
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
|
||||||
|
mv "$f.tmp" "$f"
|
||||||
|
done
|
||||||
|
|
||||||
|
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
|
||||||
|
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" nebula:smoke-relay -config host2.yml -test
|
||||||
|
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" nebula:smoke-relay -config host3.yml -test
|
||||||
|
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" nebula:smoke-relay -config host4.yml -test
|
||||||
|
|
||||||
|
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
set +x
|
set +x
|
||||||
@@ -76,7 +109,13 @@ docker exec host4 sh -c 'kill 1'
|
|||||||
docker exec host3 sh -c 'kill 1'
|
docker exec host3 sh -c 'kill 1'
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 5
|
|
||||||
|
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
|
||||||
|
# fixed sleep.
|
||||||
|
for _ in $(seq 1 30); do
|
||||||
|
[ -z "$(jobs -r)" ] && break
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
if [ "$(jobs -r)" ]
|
if [ "$(jobs -r)" ]
|
||||||
then
|
then
|
||||||
|
|||||||
47
.github/workflows/smoke/smoke-vagrant.sh
vendored
47
.github/workflows/smoke/smoke-vagrant.sh
vendored
@@ -8,6 +8,8 @@ export VAGRANT_CWD="$PWD/vagrant-$1"
|
|||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
NETWORK="nebula-smoke"
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
echo
|
echo
|
||||||
echo " *** cleanup"
|
echo " *** cleanup"
|
||||||
@@ -19,21 +21,51 @@ cleanup() {
|
|||||||
docker kill lighthouse1 host2
|
docker kill lighthouse1 host2
|
||||||
fi
|
fi
|
||||||
vagrant destroy -f
|
vagrant destroy -f
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1
|
||||||
}
|
}
|
||||||
|
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
# Create a dedicated smoke network with an explicit subnet (required for --ip
|
||||||
|
# below). Probe a short list of candidates so a locally-used range doesn't
|
||||||
|
# fail the whole test — we only need one to be free.
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1 || true
|
||||||
|
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
|
||||||
|
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
|
||||||
|
# .3 host2 — matches the placeholders in build.sh.
|
||||||
|
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
|
||||||
|
PREFIX="${SUBNET%/*}"
|
||||||
|
PREFIX="${PREFIX%.*}"
|
||||||
|
LIGHTHOUSE_IP="$PREFIX.2"
|
||||||
|
HOST2_IP="$PREFIX.3"
|
||||||
|
|
||||||
|
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
|
||||||
|
# This must happen before `vagrant up` rsyncs build/ into the VM for host3.
|
||||||
|
for f in build/host2.yml build/host3.yml; do
|
||||||
|
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
|
||||||
|
mv "$f.tmp" "$f"
|
||||||
|
done
|
||||||
|
|
||||||
CONTAINER="nebula:${NAME:-smoke}"
|
CONTAINER="nebula:${NAME:-smoke}"
|
||||||
|
|
||||||
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
||||||
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
|
||||||
|
|
||||||
vagrant up
|
vagrant up
|
||||||
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
|
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||||
sleep 15
|
sleep 15
|
||||||
@@ -96,7 +128,14 @@ vagrant ssh -c "ping -c1 192.168.100.2" -- -T
|
|||||||
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
|
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 1
|
|
||||||
|
# Wait up to 30s for all backgrounded jobs to exit. vagrant ssh in particular
|
||||||
|
# takes a beat to tear down after nebula exits on the VM, so a fixed sleep is
|
||||||
|
# racy.
|
||||||
|
for _ in $(seq 1 30); do
|
||||||
|
[ -z "$(jobs -r)" ] && break
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
if [ "$(jobs -r)" ]
|
if [ "$(jobs -r)" ]
|
||||||
then
|
then
|
||||||
|
|||||||
76
.github/workflows/smoke/smoke.sh
vendored
76
.github/workflows/smoke/smoke.sh
vendored
@@ -6,6 +6,8 @@ set -o pipefail
|
|||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
NETWORK="nebula-smoke"
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
echo
|
echo
|
||||||
echo " *** cleanup"
|
echo " *** cleanup"
|
||||||
@@ -16,38 +18,71 @@ cleanup() {
|
|||||||
then
|
then
|
||||||
docker kill lighthouse1 host2 host3 host4
|
docker kill lighthouse1 host2 host3 host4
|
||||||
fi
|
fi
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1
|
||||||
}
|
}
|
||||||
|
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
# Create a dedicated smoke network with an explicit subnet (required for --ip
|
||||||
|
# below). Probe a short list of candidates so a locally-used range doesn't
|
||||||
|
# fail the whole test — we only need one to be free.
|
||||||
|
docker network rm "$NETWORK" >/dev/null 2>&1 || true
|
||||||
|
for candidate in 172.30.0.0/24 172.31.0.0/24 10.98.0.0/24 10.99.0.0/24 192.168.230.0/24; do
|
||||||
|
if docker network create --subnet "$candidate" "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if ! docker network inspect "$NETWORK" >/dev/null 2>&1; then
|
||||||
|
echo "failed to create $NETWORK: every candidate subnet is in use" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Derive container IPs from the network's assigned subnet. Slots: .2 lighthouse1,
|
||||||
|
# .3 host2, .4 host3, .5 host4 — matches the placeholders in build.sh.
|
||||||
|
SUBNET="$(docker network inspect -f '{{(index .IPAM.Config 0).Subnet}}' "$NETWORK")"
|
||||||
|
PREFIX="${SUBNET%/*}"
|
||||||
|
PREFIX="${PREFIX%.*}"
|
||||||
|
LIGHTHOUSE_IP="$PREFIX.2"
|
||||||
|
HOST2_IP="$PREFIX.3"
|
||||||
|
HOST3_IP="$PREFIX.4"
|
||||||
|
HOST4_IP="$PREFIX.5"
|
||||||
|
|
||||||
|
# Sed the placeholder TEST-NET-3 IPs in the host configs to the real ones.
|
||||||
|
# build/lighthouse1.yml has no IPs to rewrite so it's skipped.
|
||||||
|
for f in build/host2.yml build/host3.yml build/host4.yml; do
|
||||||
|
sed "s|203\.0\.113\.|$PREFIX.|g" "$f" >"$f.tmp"
|
||||||
|
mv "$f.tmp" "$f"
|
||||||
|
done
|
||||||
|
|
||||||
CONTAINER="nebula:${NAME:-smoke}"
|
CONTAINER="nebula:${NAME:-smoke}"
|
||||||
|
|
||||||
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
||||||
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
docker run --name host2 --rm -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" "$CONTAINER" -config host2.yml -test
|
||||||
docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
|
docker run --name host3 --rm -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" "$CONTAINER" -config host3.yml -test
|
||||||
docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
|
docker run --name host4 --rm -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" "$CONTAINER" -config host4.yml -test
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
docker run --name lighthouse1 --network "$NETWORK" --ip "$LIGHTHOUSE_IP" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --network "$NETWORK" --ip "$HOST2_IP" -v "$PWD/build/host2.yml:/nebula/host2.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
docker run --name host3 --network "$NETWORK" --ip "$HOST3_IP" -v "$PWD/build/host3.yml:/nebula/host3.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
docker run --name host4 --network "$NETWORK" --ip "$HOST4_IP" -v "$PWD/build/host4.yml:/nebula/host4.yml:ro" --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
# grab tcpdump pcaps for debugging
|
# grab tcpdump pcaps for debugging
|
||||||
docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
|
docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
|
||||||
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
|
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
|
||||||
docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
|
docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
|
||||||
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
|
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
|
||||||
docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
||||||
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
||||||
docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
|
docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
|
||||||
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
|
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
|
||||||
|
|
||||||
docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
||||||
docker exec host3 ncat -nklv 0.0.0.0 2000 &
|
docker exec host3 ncat -nklv 0.0.0.0 2000 &
|
||||||
|
docker exec host4 ncat -e '/usr/bin/echo helloagainfromhost4' -nkluv 0.0.0.0 4000 &
|
||||||
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
||||||
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
|
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
|
||||||
|
|
||||||
@@ -119,17 +154,24 @@ echo
|
|||||||
echo " *** Testing conntrack"
|
echo " *** Testing conntrack"
|
||||||
echo
|
echo
|
||||||
set -x
|
set -x
|
||||||
# host2 can ping host3 now that host3 pinged it first
|
|
||||||
docker exec host2 ping -c1 192.168.100.3
|
# host4's outbound firewall only allows ICMP to the lighthouse, so host4
|
||||||
# host4 can ping host2 once conntrack established
|
# cannot initiate UDP to host2. Once host2 initiates a flow to host4:4000,
|
||||||
docker exec host2 ping -c1 192.168.100.4
|
# conntrack must let host4's listener reply on that flow. If it doesn't,
|
||||||
docker exec host4 ping -c1 192.168.100.2
|
# the echo back from host4 never reaches host2.
|
||||||
|
docker exec host2 sh -c "(/usr/bin/echo host2; sleep 2) | ncat -nuv 192.168.100.4 4000" | grep -q helloagainfromhost4
|
||||||
|
|
||||||
docker exec host4 sh -c 'kill 1'
|
docker exec host4 sh -c 'kill 1'
|
||||||
docker exec host3 sh -c 'kill 1'
|
docker exec host3 sh -c 'kill 1'
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 5
|
|
||||||
|
# Wait up to 30s for all backgrounded jobs to exit rather than relying on a
|
||||||
|
# fixed sleep.
|
||||||
|
for _ in $(seq 1 30); do
|
||||||
|
[ -z "$(jobs -r)" ] && break
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
if [ "$(jobs -r)" ]
|
if [ "$(jobs -r)" ]
|
||||||
then
|
then
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# -*- mode: ruby -*-
|
# -*- mode: ruby -*-
|
||||||
# vi: set ft=ruby :
|
# vi: set ft=ruby :
|
||||||
Vagrant.configure("2") do |config|
|
Vagrant.configure("2") do |config|
|
||||||
config.vm.box = "ubuntu/jammy64"
|
config.vm.box = "bento/ubuntu-24.04"
|
||||||
|
|
||||||
config.vm.synced_folder "../build", "/nebula"
|
config.vm.synced_folder "../build", "/nebula"
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# -*- mode: ruby -*-
|
# -*- mode: ruby -*-
|
||||||
# vi: set ft=ruby :
|
# vi: set ft=ruby :
|
||||||
Vagrant.configure("2") do |config|
|
Vagrant.configure("2") do |config|
|
||||||
config.vm.box = "generic/openbsd7"
|
config.vm.box = "DefinedNet/openbsd78"
|
||||||
|
|
||||||
config.vm.synced_folder "../build", "/nebula", type: "rsync"
|
config.vm.synced_folder "../build", "/nebula", type: "rsync"
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -2,7 +2,21 @@ version: "2"
|
|||||||
linters:
|
linters:
|
||||||
default: none
|
default: none
|
||||||
enable:
|
enable:
|
||||||
|
- sloglint
|
||||||
- testifylint
|
- testifylint
|
||||||
|
settings:
|
||||||
|
sloglint:
|
||||||
|
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
|
||||||
|
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
|
||||||
|
# custom levels instead of LogAttrs when you can.
|
||||||
|
#
|
||||||
|
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
|
||||||
|
# the few legitimate sites (where attrs is built up as a []slog.Attr)
|
||||||
|
# carry a //nolint:sloglint with rationale.
|
||||||
|
kv-only: true
|
||||||
|
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
|
||||||
|
# discard-handler is on by default (since Go 1.24): suggests
|
||||||
|
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
|
||||||
exclusions:
|
exclusions:
|
||||||
generated: lax
|
generated: lax
|
||||||
presets:
|
presets:
|
||||||
|
|||||||
16
CHANGELOG.md
16
CHANGELOG.md
@@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [1.10.3] - 2026-02-06
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations.
|
||||||
|
Both fingerprint representations will be tested against the blocklist.
|
||||||
|
Any newly issued P256 based certificates will have their signature clamped to the low-s form.
|
||||||
|
Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588)
|
||||||
|
|
||||||
## [1.10.2] - 2026-01-21
|
## [1.10.2] - 2026-01-21
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
@@ -775,7 +788,8 @@ created.)
|
|||||||
|
|
||||||
- Initial public release.
|
- Initial public release.
|
||||||
|
|
||||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.2...HEAD
|
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD
|
||||||
|
[1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3
|
||||||
[1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2
|
[1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2
|
||||||
[1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1
|
[1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1
|
||||||
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
|
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
|
||||||
|
|||||||
1
CODEOWNERS
Normal file
1
CODEOWNERS
Normal file
@@ -0,0 +1 @@
|
|||||||
|
#ECCN:Open Source
|
||||||
@@ -57,7 +57,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
|||||||
docker pull nebulaoss/nebula
|
docker pull nebulaoss/nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Mobile
|
#### Mobile ([source code](https://github.com/DefinedNet/mobile_nebula))
|
||||||
|
|
||||||
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200)
|
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200)
|
||||||
- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
|
- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
|
||||||
@@ -76,6 +76,8 @@ Nebula was created to provide a mechanism for groups of hosts to communicate sec
|
|||||||
|
|
||||||
## Getting started (quickly)
|
## Getting started (quickly)
|
||||||
|
|
||||||
|
**Don't want to manage your own PKI and lighthouses?** [Managed Nebula](https://www.defined.net/) from Defined Networking handles all of this for you.
|
||||||
|
|
||||||
To set up a Nebula network, you'll need:
|
To set up a Nebula network, you'll need:
|
||||||
|
|
||||||
#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use.
|
#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use.
|
||||||
|
|||||||
38
bits.go
38
bits.go
@@ -1,8 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bits struct {
|
type Bits struct {
|
||||||
@@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true.
|
// If i is the next number, return true.
|
||||||
if i > b.current {
|
if i > b.current {
|
||||||
return true
|
return true
|
||||||
@@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
l.Debug("rejected a packet (top)",
|
||||||
|
"current", b.current,
|
||||||
|
"incoming", i,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true and update current.
|
// If i is the next number, return true and update current.
|
||||||
if i == b.current+1 {
|
if i == b.current+1 {
|
||||||
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
||||||
@@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
// Check to see if it's a duplicate
|
// Check to see if it's a duplicate
|
||||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||||
if b.current == i || b.bits[i%b.length] == true {
|
if b.current == i || b.bits[i%b.length] == true {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
l.Debug("Receive window",
|
||||||
Debug("Receive window")
|
"accepted", false,
|
||||||
|
"currentCounter", b.current,
|
||||||
|
"incomingCounter", i,
|
||||||
|
"reason", "duplicate",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
b.dupeCounter.Inc(1)
|
b.dupeCounter.Inc(1)
|
||||||
return false
|
return false
|
||||||
@@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
|
|
||||||
// In all other cases, fail and don't change current.
|
// In all other cases, fail and don't change current.
|
||||||
b.outOfWindowCounter.Inc(1)
|
b.outOfWindowCounter.Inc(1)
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("accepted", false).
|
l.Debug("Receive window",
|
||||||
WithField("currentCounter", b.current).
|
"accepted", false,
|
||||||
WithField("incomingCounter", i).
|
"currentCounter", b.current,
|
||||||
WithField("reason", "nonsense").
|
"incomingCounter", i,
|
||||||
Debug("Receive window")
|
"reason", "nonsense",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build boringcrypto
|
//go:build boringcrypto
|
||||||
// +build boringcrypto
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,22 +32,46 @@ func NewCAPool() *CAPool {
|
|||||||
// If the pool contains any expired certificates, an ErrExpired will be
|
// If the pool contains any expired certificates, an ErrExpired will be
|
||||||
// returned along with the pool. The caller must handle any such errors.
|
// returned along with the pool. The caller must handle any such errors.
|
||||||
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
||||||
|
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
|
||||||
|
// The reader must contain a PEM-encoded set of nebula certificates.
|
||||||
|
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
|
||||||
pool := NewCAPool()
|
pool := NewCAPool()
|
||||||
var err error
|
|
||||||
var expired bool
|
var expired bool
|
||||||
for {
|
|
||||||
caPEMs, err = pool.AddCAFromPEM(caPEMs)
|
scanner := bufio.NewScanner(r)
|
||||||
if errors.Is(err, ErrExpired) {
|
scanner.Split(SplitPEM)
|
||||||
expired = true
|
|
||||||
err = nil
|
for scanner.Scan() {
|
||||||
|
pemBytes := scanner.Bytes()
|
||||||
|
|
||||||
|
block, rest := pem.Decode(pemBytes)
|
||||||
|
if len(bytes.TrimSpace(rest)) > 0 {
|
||||||
|
return nil, ErrInvalidPEMBlock
|
||||||
}
|
}
|
||||||
|
if block == nil {
|
||||||
|
return nil, ErrInvalidPEMBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := unmarshalCertificateBlock(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
|
|
||||||
break
|
err = pool.AddCA(c)
|
||||||
|
if errors.Is(err, ErrExpired) {
|
||||||
|
expired = true
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, ErrInvalidPEMBlock
|
||||||
|
}
|
||||||
|
|
||||||
if expired {
|
if expired {
|
||||||
return pool, ErrExpired
|
return pool, ErrExpired
|
||||||
@@ -141,10 +168,23 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre nebula v1.10.3 could generate signatures in either high or low s form and validation
|
||||||
|
// of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form
|
||||||
|
// but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we
|
||||||
|
// need to test both forms until such a time comes that we enforce low-s form on signature validation.
|
||||||
|
fp2, err := CalculateAlternateFingerprint(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err)
|
||||||
|
}
|
||||||
|
if fp2 != "" && ncp.IsBlocklisted(fp2) {
|
||||||
|
return nil, ErrBlockListed
|
||||||
|
}
|
||||||
|
|
||||||
cc := CachedCertificate{
|
cc := CachedCertificate{
|
||||||
Certificate: c,
|
Certificate: c,
|
||||||
InvertedGroups: make(map[string]struct{}),
|
InvertedGroups: make(map[string]struct{}),
|
||||||
Fingerprint: fp,
|
Fingerprint: fp,
|
||||||
|
fingerprint2: fp2,
|
||||||
signerFingerprint: signer.Fingerprint,
|
signerFingerprint: signer.Fingerprint,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,6 +198,11 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti
|
|||||||
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
|
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
|
||||||
// is a cheaper operation to perform as a result.
|
// is a cheaper operation to perform as a result.
|
||||||
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
|
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
|
||||||
|
// Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s
|
||||||
|
if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) {
|
||||||
|
return ErrBlockListed
|
||||||
|
}
|
||||||
|
|
||||||
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
|
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -111,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
|
|||||||
assert.Len(t, ppppp.CAs, 1)
|
assert.Len(t, ppppp.CAs, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oneByteReader wraps a reader to return at most 1 byte per Read call,
|
||||||
|
// exercising the streaming accumulation logic in NewCAPoolFromPEMReader.
|
||||||
|
type oneByteReader struct {
|
||||||
|
r io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *oneByteReader) Read(p []byte) (int, error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return o.r.Read(p[:1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) {
|
||||||
|
pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, pool.CAs)
|
||||||
|
|
||||||
|
pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n "))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, pool.CAs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) {
|
||||||
|
ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
|
ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
|
|
||||||
|
bundle := append(pem1, pem2...)
|
||||||
|
pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, pool.CAs, 2)
|
||||||
|
|
||||||
|
fp1, err := ca1.Fingerprint()
|
||||||
|
require.NoError(t, err)
|
||||||
|
fp2, err := ca2.Fingerprint()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Contains(t, pool.CAs, fp1)
|
||||||
|
assert.Contains(t, pool.CAs, fp2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) {
|
||||||
|
_, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata"))
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) {
|
||||||
|
_, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
|
|
||||||
|
bundle := append(pem1, []byte("some trailing garbage")...)
|
||||||
|
_, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle))
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify(t *testing.T) {
|
func TestCertificateV1_Verify(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
@@ -170,6 +228,15 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
// Create a copy of the cert and swap to the alternate form for the signature
|
||||||
|
nc := c.Copy()
|
||||||
|
b, err := p256.Swap(c.Signature())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, nc.(*certificateV1).setSignature(b))
|
||||||
|
|
||||||
|
_, err = caPool.VerifyCertificate(time.Now(), nc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -187,7 +254,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err = caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
@@ -196,7 +263,17 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reset the blocklist and block the alternate form fingerprint
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
caPool.BlocklistFingerprint(cc.fingerprint2)
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,6 +471,15 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
// Create a copy of the cert and swap to the alternate form for the signature
|
||||||
|
nc := c.Copy()
|
||||||
|
b, err := p256.Swap(c.Signature())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, nc.(*certificateV2).setSignature(b))
|
||||||
|
|
||||||
|
_, err = caPool.VerifyCertificate(time.Now(), nc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -411,7 +497,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err = caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
@@ -420,7 +506,17 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reset the blocklist and block the alternate form fingerprint
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
caPool.BlocklistFingerprint(cc.fingerprint2)
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
|
caPool.ResetCertBlocklist()
|
||||||
|
err = caPool.VerifyCachedCertificate(time.Now(), cc)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
33
cert/cert.go
33
cert/cert.go
@@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Version uint8
|
type Version uint8
|
||||||
@@ -110,6 +112,9 @@ type CachedCertificate struct {
|
|||||||
InvertedGroups map[string]struct{}
|
InvertedGroups map[string]struct{}
|
||||||
Fingerprint string
|
Fingerprint string
|
||||||
signerFingerprint string
|
signerFingerprint string
|
||||||
|
|
||||||
|
// A place to store a 2nd fingerprint if the certificate could have one, such as with P256
|
||||||
|
fingerprint2 string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *CachedCertificate) String() string {
|
func (cc *CachedCertificate) String() string {
|
||||||
@@ -152,3 +157,31 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
|
|||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates
|
||||||
|
// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step.
|
||||||
|
func CalculateAlternateFingerprint(c Certificate) (string, error) {
|
||||||
|
if c.Curve() != Curve_P256 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nc := c.Copy()
|
||||||
|
b, err := p256.Swap(nc.Signature())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := nc.(type) {
|
||||||
|
case *certificateV1:
|
||||||
|
err = v.setSignature(b)
|
||||||
|
case *certificateV2:
|
||||||
|
err = v.setSignature(b)
|
||||||
|
default:
|
||||||
|
return "", ErrUnknownVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return nc.Fingerprint()
|
||||||
|
}
|
||||||
|
|||||||
127
cert/p256/p256.go
Normal file
127
cert/p256/p256.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package p256
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/elliptic"
|
||||||
|
"errors"
|
||||||
|
"math/big"
|
||||||
|
|
||||||
|
"filippo.io/bigmod"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/cryptobyte"
|
||||||
|
"golang.org/x/crypto/cryptobyte/asn1"
|
||||||
|
)
|
||||||
|
|
||||||
|
var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1)
|
||||||
|
var nMod *bigmod.Modulus
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
nMod = n
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsNormalized(sig []byte) (bool, error) {
|
||||||
|
r, s, err := parseSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return checkLowS(r, s), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkLowS(_, s []byte) bool {
|
||||||
|
bigS := new(big.Int).SetBytes(s)
|
||||||
|
// Check if S <= (N/2), because we want to include the midpoint in the set of low-s
|
||||||
|
return bigS.Cmp(halfN) <= 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func swap(r, s []byte) ([]byte, []byte, error) {
|
||||||
|
var err error
|
||||||
|
bigS, err := bigmod.NewNat().SetBytes(s, nMod)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
sNormalized := nMod.Nat().Sub(bigS, nMod)
|
||||||
|
|
||||||
|
result := sNormalized.Bytes(nMod)
|
||||||
|
for len(result) > 1 && result[0] == 0 {
|
||||||
|
result = result[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Normalize(sig []byte) ([]byte, error) {
|
||||||
|
r, s, err := parseSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if checkLowS(r, s) {
|
||||||
|
return sig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newR, newS, err := swap(r, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return encodeSignature(newR, newS)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap will change sig between its current form to the opposite high or low form.
|
||||||
|
func Swap(sig []byte) ([]byte, error) {
|
||||||
|
r, s, err := parseSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newR, newS, err := swap(r, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return encodeSignature(newR, newS)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSignature taken exactly from crypto/ecdsa/ecdsa.go
|
||||||
|
func parseSignature(sig []byte) (r, s []byte, err error) {
|
||||||
|
var inner cryptobyte.String
|
||||||
|
input := cryptobyte.String(sig)
|
||||||
|
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
|
||||||
|
!input.Empty() ||
|
||||||
|
!inner.ReadASN1Integer(&r) ||
|
||||||
|
!inner.ReadASN1Integer(&s) ||
|
||||||
|
!inner.Empty() {
|
||||||
|
return nil, nil, errors.New("invalid ASN.1")
|
||||||
|
}
|
||||||
|
return r, s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeSignature(r, s []byte) ([]byte, error) {
|
||||||
|
var b cryptobyte.Builder
|
||||||
|
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||||
|
addASN1IntBytes(b, r)
|
||||||
|
addASN1IntBytes(b, s)
|
||||||
|
})
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// addASN1IntBytes encodes in ASN.1 a positive integer represented as
|
||||||
|
// a big-endian byte slice with zero or more leading zeroes.
|
||||||
|
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
|
||||||
|
for len(bytes) > 0 && bytes[0] == 0 {
|
||||||
|
bytes = bytes[1:]
|
||||||
|
}
|
||||||
|
if len(bytes) == 0 {
|
||||||
|
b.SetError(errors.New("invalid integer"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) {
|
||||||
|
if bytes[0]&0x80 != 0 {
|
||||||
|
c.AddUint8(0)
|
||||||
|
}
|
||||||
|
c.AddBytes(bytes)
|
||||||
|
})
|
||||||
|
}
|
||||||
28
cert/p256/p256_test.go
Normal file
28
cert/p256/p256_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package p256
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlipping(t *testing.T) {
|
||||||
|
priv, err1 := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
|
||||||
|
out, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r, s, err := parseSignature(out)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r, s1, err := swap(r, s)
|
||||||
|
require.NoError(t, err)
|
||||||
|
r, s2, err := swap(r, s1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, s, s2)
|
||||||
|
require.NotEqual(t, s, s1)
|
||||||
|
}
|
||||||
82
cert/pem.go
82
cert/pem.go
@@ -1,12 +1,66 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrTruncatedPEMBlock = errors.New("truncated PEM block")
|
||||||
|
|
||||||
|
// SplitPEM is a split function for bufio.Scanner that returns each PEM block.
|
||||||
|
func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
// Look for the start of a PEM block
|
||||||
|
start := bytes.Index(data, []byte("-----BEGIN "))
|
||||||
|
if start == -1 {
|
||||||
|
if atEOF && len(bytes.TrimSpace(data)) > 0 {
|
||||||
|
// Non-whitespace content with no PEM block
|
||||||
|
return 0, nil, ErrTruncatedPEMBlock
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), nil, nil
|
||||||
|
}
|
||||||
|
// Request more data
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for the end marker
|
||||||
|
endMarkerStart := bytes.Index(data[start:], []byte("-----END "))
|
||||||
|
if endMarkerStart == -1 {
|
||||||
|
if atEOF {
|
||||||
|
// Incomplete PEM block at EOF
|
||||||
|
return 0, nil, ErrTruncatedPEMBlock
|
||||||
|
}
|
||||||
|
// Need more data to find the end
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the actual end of the END line (after the newline)
|
||||||
|
endMarkerStart += start
|
||||||
|
endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n')
|
||||||
|
var end int
|
||||||
|
if endLineEnd == -1 {
|
||||||
|
if atEOF {
|
||||||
|
// END marker without newline at EOF - take it anyway
|
||||||
|
end = len(data)
|
||||||
|
} else {
|
||||||
|
// Need more data
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
end = endMarkerStart + endLineEnd + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the PEM block
|
||||||
|
pemBlock := data[start:end]
|
||||||
|
|
||||||
|
// Return the valid PEM block
|
||||||
|
return end, pemBlock, nil
|
||||||
|
}
|
||||||
|
|
||||||
const ( //cert banners
|
const ( //cert banners
|
||||||
CertificateBanner = "NEBULA CERTIFICATE"
|
CertificateBanner = "NEBULA CERTIFICATE"
|
||||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||||
@@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
return nil, r, ErrInvalidPEMBlock
|
return nil, r, ErrInvalidPEMBlock
|
||||||
}
|
}
|
||||||
|
|
||||||
var c Certificate
|
c, err := unmarshalCertificateBlock(p)
|
||||||
var err error
|
|
||||||
|
|
||||||
switch p.Type {
|
|
||||||
// Implementations must validate the resulting certificate contains valid information
|
|
||||||
case CertificateBanner:
|
|
||||||
c, err = unmarshalCertificateV1(p.Bytes, nil)
|
|
||||||
case CertificateV2Banner:
|
|
||||||
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
|
|
||||||
default:
|
|
||||||
return nil, r, ErrInvalidPEMCertificateBanner
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, r, err
|
return nil, r, err
|
||||||
}
|
}
|
||||||
@@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
|
||||||
|
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
|
||||||
|
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
|
||||||
|
switch block.Type {
|
||||||
|
// Implementations must validate the resulting certificate contains valid information
|
||||||
|
case CertificateBanner:
|
||||||
|
return unmarshalCertificateV1(block.Bytes, nil)
|
||||||
|
case CertificateV2Banner:
|
||||||
|
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
|
||||||
|
default:
|
||||||
|
return nil, ErrInvalidPEMCertificateBanner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||||
if c.IsCA() {
|
if c.IsCA() {
|
||||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||||
|
|||||||
@@ -1,12 +1,88 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func scanAll(t *testing.T, input string) ([]string, error) {
|
||||||
|
t.Helper()
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||||
|
scanner.Split(SplitPEM)
|
||||||
|
var blocks []string
|
||||||
|
for scanner.Scan() {
|
||||||
|
blocks = append(blocks, scanner.Text())
|
||||||
|
}
|
||||||
|
return blocks, scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_Single(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 1)
|
||||||
|
require.Equal(t, input, blocks[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_Multiple(t *testing.T) {
|
||||||
|
block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n"
|
||||||
|
block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
|
||||||
|
blocks, err := scanAll(t, block1+block2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 2)
|
||||||
|
require.Equal(t, block1, blocks[0])
|
||||||
|
require.Equal(t, block2, blocks[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) {
|
||||||
|
input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_Empty(t *testing.T) {
|
||||||
|
blocks, err := scanAll(t, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_WhitespaceOnly(t *testing.T) {
|
||||||
|
blocks, err := scanAll(t, " \n\t\n ")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_TrailingGarbage(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||||
|
require.Len(t, blocks, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_TruncatedBlock(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\npartial data with no end"
|
||||||
|
_, err := scanAll(t, input)
|
||||||
|
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_NoEndNewline(t *testing.T) {
|
||||||
|
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----"
|
||||||
|
blocks, err := scanAll(t, input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, blocks, 1)
|
||||||
|
require.Equal(t, input, blocks[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitPEM_GarbageOnly(t *testing.T) {
|
||||||
|
_, err := scanAll(t, "this is not PEM data")
|
||||||
|
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
||||||
goodCert := []byte(`
|
goodCert := []byte(`
|
||||||
# A good cert
|
# A good cert
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TBSCertificate represents a certificate intended to be signed.
|
// TBSCertificate represents a certificate intended to be signed.
|
||||||
@@ -126,6 +128,13 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if curve == Curve_P256 {
|
||||||
|
sig, err = p256.Normalize(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = c.setSignature(sig)
|
err = c.setSignature(sig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert/p256"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -89,3 +90,48 @@ func TestCertificateV1_SignP256(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, uc)
|
assert.NotNil(t, uc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) {
|
||||||
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
|
pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
|
||||||
|
|
||||||
|
tbs := TBSCertificate{
|
||||||
|
Version: Version1,
|
||||||
|
Name: "testing",
|
||||||
|
Networks: []netip.Prefix{
|
||||||
|
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||||
|
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||||
|
},
|
||||||
|
UnsafeNetworks: []netip.Prefix{
|
||||||
|
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||||
|
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||||
|
},
|
||||||
|
Groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||||
|
NotBefore: before,
|
||||||
|
NotAfter: after,
|
||||||
|
PublicKey: pubKey,
|
||||||
|
IsCA: true,
|
||||||
|
Curve: Curve_P256,
|
||||||
|
}
|
||||||
|
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
|
||||||
|
rawPriv := priv.D.FillBytes(make([]byte, 32))
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
if i&1 == 1 {
|
||||||
|
tbs.Version = Version1
|
||||||
|
} else {
|
||||||
|
tbs.Version = Version2
|
||||||
|
}
|
||||||
|
c, err := tbs.Sign(nil, Curve_P256, rawPriv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.True(t, c.CheckSignature(pub))
|
||||||
|
normie, err := p256.IsNormalized(c.Signature())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, normie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
@@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rawCACert, err := os.ReadFile(*vf.caPath)
|
caFile, err := os.Open(*vf.caPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while reading ca: %w", err)
|
return fmt.Errorf("error while reading ca: %w", err)
|
||||||
}
|
}
|
||||||
|
defer caFile.Close()
|
||||||
|
|
||||||
caPool := cert.NewCAPool()
|
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
|
||||||
for {
|
if err != nil && !errors.Is(err, cert.ErrExpired) {
|
||||||
rawCACert, err = caPool.AddCAFromPEM(rawCACert)
|
return fmt.Errorf("error while adding ca cert to pool: %w", err)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while adding ca cert to pool: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rawCert, err := os.ReadFile(*vf.certPath)
|
rawCert, err := os.ReadFile(*vf.certPath)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func Test_verify(t *testing.T) {
|
|||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
require.ErrorIs(t, err, cert.ErrInvalidPEMBlock)
|
||||||
|
|
||||||
// make a ca for later
|
// make a ca for later
|
||||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
|||||||
@@ -3,8 +3,15 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
func HookLogger(l *logrus.Logger) {
|
"github.com/slackhq/nebula/logging"
|
||||||
// Do nothing, let the logs flow to stdout/stderr
|
)
|
||||||
|
|
||||||
|
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
|
||||||
|
// platforms have no special sink to integrate with.
|
||||||
|
func newPlatformLogger() *slog.Logger {
|
||||||
|
return logging.NewLogger(os.Stdout)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,54 +1,86 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"context"
|
||||||
"io/ioutil"
|
"log/slog"
|
||||||
"os"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
|
// newPlatformLogger returns a *slog.Logger that routes every log record
|
||||||
// logrus output will be discarded
|
// through the Windows service logger so records end up in the Windows
|
||||||
func HookLogger(l *logrus.Logger) {
|
// Event Log. All the heavy lifting (level management, format swap,
|
||||||
l.AddHook(newLogHook(logger))
|
// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
|
||||||
l.SetOutput(ioutil.Discard)
|
// this file only contributes:
|
||||||
|
//
|
||||||
|
// - an io.Writer that forwards each formatted line to the service
|
||||||
|
// logger at the current record's Event Log severity, and
|
||||||
|
// - a thin severityTag that embeds *logging.Handler and overrides
|
||||||
|
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
|
||||||
|
// column and severity-based filters keep working the way they did
|
||||||
|
// before the slog migration.
|
||||||
|
//
|
||||||
|
// Format (text vs json) is carried by the embedded *logging.Handler, so
|
||||||
|
// logging.format: json in config still produces JSON lines in Event
|
||||||
|
// Viewer, same as the pre-slog logrus setup.
|
||||||
|
func newPlatformLogger() *slog.Logger {
|
||||||
|
w := &eventLogWriter{}
|
||||||
|
return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w})
|
||||||
}
|
}
|
||||||
|
|
||||||
type logHook struct {
|
// eventLogWriter forwards slog-formatted lines to the Windows service
|
||||||
sl service.Logger
|
// logger at the severity most recently stashed by severityTag.Handle.
|
||||||
|
// The mutex serializes the stash + inner.Handle + Write cycle per record
|
||||||
|
// across all concurrent goroutines; slog's builtin text/json handlers
|
||||||
|
// each hold their own mutex around Write, but that only protects the
|
||||||
|
// Write call itself, not our stash-then-handle sequence.
|
||||||
|
type eventLogWriter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
level slog.Level
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLogHook(sl service.Logger) *logHook {
|
func (w *eventLogWriter) Write(p []byte) (int, error) {
|
||||||
return &logHook{sl: sl}
|
line := strings.TrimRight(string(p), "\n")
|
||||||
}
|
switch {
|
||||||
|
case w.level >= slog.LevelError:
|
||||||
func (h *logHook) Fire(entry *logrus.Entry) error {
|
return len(p), logger.Error(line)
|
||||||
line, err := entry.String()
|
case w.level >= slog.LevelWarn:
|
||||||
if err != nil {
|
return len(p), logger.Warning(line)
|
||||||
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch entry.Level {
|
|
||||||
case logrus.PanicLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.FatalLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.ErrorLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.WarnLevel:
|
|
||||||
return h.sl.Warning(line)
|
|
||||||
case logrus.InfoLevel:
|
|
||||||
return h.sl.Info(line)
|
|
||||||
case logrus.DebugLevel:
|
|
||||||
return h.sl.Info(line)
|
|
||||||
default:
|
default:
|
||||||
return nil
|
return len(p), logger.Info(line)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *logHook) Levels() []logrus.Level {
|
// severityTag embeds *logging.Handler to pick up everything it does for
|
||||||
return logrus.AllLevels
|
// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
|
||||||
|
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
|
||||||
|
// so each record's slog.Level is stashed on the writer before formatting
|
||||||
|
// and so derived handlers stay wrapped as severityTag rather than
|
||||||
|
// downgrading to bare *logging.Handler.
|
||||||
|
type severityTag struct {
|
||||||
|
*logging.Handler
|
||||||
|
w *eventLogWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
|
||||||
|
s.w.mu.Lock()
|
||||||
|
defer s.w.mu.Unlock()
|
||||||
|
s.w.level = r.Level
|
||||||
|
return s.Handler.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) WithGroup(name string) slog.Handler {
|
||||||
|
if name == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,9 +50,14 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
l := logging.NewLogger(os.Stdout)
|
||||||
|
|
||||||
if *serviceFlag != "" {
|
if *serviceFlag != "" {
|
||||||
doService(configPath, configTest, Build, serviceFlag)
|
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
|
||||||
os.Exit(1)
|
l.Error("Service command failed", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if *configPath == "" {
|
if *configPath == "" {
|
||||||
@@ -61,9 +66,6 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
|
||||||
l.Out = os.Stdout
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*configPath)
|
err := c.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -71,6 +73,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
fmt.Printf("failed to apply logging config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||||
@@ -78,8 +90,20 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
ctrl.ShutdownBlock()
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
|
|
||||||
|
if err := wait(); err != nil {
|
||||||
|
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger service.Logger
|
var logger service.Logger
|
||||||
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
|
|||||||
// Start should not block.
|
// Start should not block.
|
||||||
logger.Info("Nebula service starting.")
|
logger.Info("Nebula service starting.")
|
||||||
|
|
||||||
l := logrus.New()
|
l := newPlatformLogger()
|
||||||
HookLogger(l)
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*p.configPath)
|
err := c.Load(*p.configPath)
|
||||||
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
|
|||||||
return fmt.Errorf("failed to load config: %s", err)
|
return fmt.Errorf("failed to load config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
return fmt.Errorf("failed to apply logging config: %s", err)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -57,11 +65,11 @@ func fileExists(filename string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
|
||||||
if *configPath == "" {
|
if *configPath == "" {
|
||||||
ex, err := os.Executable()
|
ex, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return err
|
||||||
}
|
}
|
||||||
*configPath = filepath.Dir(ex) + "/config.yaml"
|
*configPath = filepath.Dir(ex) + "/config.yaml"
|
||||||
if !fileExists(*configPath) {
|
if !fileExists(*configPath) {
|
||||||
@@ -85,16 +93,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
|
|||||||
// Here are what the different loggers are doing:
|
// Here are what the different loggers are doing:
|
||||||
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
|
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
|
||||||
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
|
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
|
||||||
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
|
// - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
|
||||||
s, err := service.New(prg, svcConfig)
|
s, err := service.New(prg, svcConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
errs := make(chan error, 5)
|
errs := make(chan error, 5)
|
||||||
logger, err = s.Logger(errs)
|
logger, err = s.Logger(errs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -109,18 +117,16 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
|
|||||||
|
|
||||||
switch *serviceFlag {
|
switch *serviceFlag {
|
||||||
case "run":
|
case "run":
|
||||||
err = s.Run()
|
if err := s.Run(); err != nil {
|
||||||
if err != nil {
|
|
||||||
// Route any errors to the system logger
|
// Route any errors to the system logger
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
err := service.Control(s, *serviceFlag)
|
if err := service.Control(s, *serviceFlag); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Printf("Valid actions: %q\n", service.ControlAction)
|
log.Printf("Valid actions: %q\n", service.ControlAction)
|
||||||
log.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,8 +55,7 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
l := logging.NewLogger(os.Stdout)
|
||||||
l.Out = os.Stdout
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*configPath)
|
err := c.Load(*configPath)
|
||||||
@@ -65,6 +64,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
fmt.Printf("failed to apply logging config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||||
@@ -72,9 +81,21 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
ctrl.ShutdownBlock()
|
|
||||||
|
if err := wait(); err != nil {
|
||||||
|
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
|
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
|
||||||
@@ -13,30 +12,30 @@ import (
|
|||||||
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
|
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
|
||||||
const SdNotifyReady = "READY=1"
|
const SdNotifyReady = "READY=1"
|
||||||
|
|
||||||
func notifyReady(l *logrus.Logger) {
|
func notifyReady(l *slog.Logger) {
|
||||||
sockName := os.Getenv("NOTIFY_SOCKET")
|
sockName := os.Getenv("NOTIFY_SOCKET")
|
||||||
if sockName == "" {
|
if sockName == "" {
|
||||||
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
|
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("failed to connect to systemd notification socket")
|
l.Error("failed to connect to systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
|
l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
|
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
|
||||||
l.WithError(err).Error("failed to signal the systemd notification socket")
|
l.Error("failed to signal the systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Debugln("notified systemd the service is ready")
|
l.Debug("notified systemd the service is ready")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import "log/slog"
|
||||||
|
|
||||||
func notifyReady(_ *logrus.Logger) {
|
func notifyReady(_ *slog.Logger) {
|
||||||
// No init service to notify
|
// No init service to notify
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -16,7 +17,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,11 +26,11 @@ type C struct {
|
|||||||
Settings map[string]any
|
Settings map[string]any
|
||||||
oldSettings map[string]any
|
oldSettings map[string]any
|
||||||
callbacks []func(*C)
|
callbacks []func(*C)
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
reloadLock sync.Mutex
|
reloadLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewC(l *logrus.Logger) *C {
|
func NewC(l *slog.Logger) *C {
|
||||||
return &C{
|
return &C{
|
||||||
Settings: make(map[string]any),
|
Settings: make(map[string]any),
|
||||||
l: l,
|
l: l,
|
||||||
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
|
|||||||
|
|
||||||
newVals, err := yaml.Marshal(nv)
|
newVals, err := yaml.Marshal(nv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
c.l.Error("Error while marshaling new config",
|
||||||
|
"config_path", k,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldVals, err := yaml.Marshal(ov)
|
oldVals, err := yaml.Marshal(ov)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
c.l.Error("Error while marshaling old config",
|
||||||
|
"config_path", k,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(newVals) != string(oldVals)
|
return string(newVals) != string(oldVals)
|
||||||
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
|
|||||||
|
|
||||||
err := c.Load(c.path)
|
err := c.Load(c.path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
c.l.Error("Error occurred while reloading config",
|
||||||
|
"config_path", c.path,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -47,10 +47,10 @@ type connectionManager struct {
|
|||||||
|
|
||||||
metricsTxPunchy metrics.Counter
|
metricsTxPunchy metrics.Counter
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||||
cm := &connectionManager{
|
cm := &connectionManager{
|
||||||
hostMap: hm,
|
hostMap: hm,
|
||||||
l: l,
|
l: l,
|
||||||
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
|||||||
old := cm.getInactivityTimeout()
|
old := cm.getInactivityTimeout()
|
||||||
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
||||||
if !initial {
|
if !initial {
|
||||||
cm.l.WithField("oldDuration", old).
|
cm.l.Info("Inactivity timeout has changed",
|
||||||
WithField("newDuration", cm.getInactivityTimeout()).
|
"oldDuration", old,
|
||||||
Info("Inactivity timeout has changed")
|
"newDuration", cm.getInactivityTimeout(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
|||||||
old := cm.dropInactive.Load()
|
old := cm.dropInactive.Load()
|
||||||
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
||||||
if !initial {
|
if !initial {
|
||||||
cm.l.WithField("oldBool", old).
|
cm.l.Info("Drop inactive setting has changed",
|
||||||
WithField("newBool", cm.dropInactive.Load()).
|
"oldBool", old,
|
||||||
Info("Drop inactive setting has changed")
|
"newBool", cm.dropInactive.Load(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
var err error
|
var err error
|
||||||
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
|
||||||
} else {
|
} else {
|
||||||
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
cm.l.WithFields(logrus.Fields{
|
cm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": req.RelayFromAddr,
|
"relayFrom", req.RelayFromAddr,
|
||||||
"relayTo": req.RelayToAddr,
|
"relayTo", req.RelayToAddr,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex", req.ResponderRelayIndex,
|
||||||
"vpnAddrs": newhostinfo.vpnAddrs}).
|
"vpnAddrs", newhostinfo.vpnAddrs,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
hostinfo := cm.hostMap.Indexes[localIndex]
|
hostinfo := cm.hostMap.Indexes[localIndex]
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
|
cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
// A hostinfo is determined alive if there is incoming traffic
|
// A hostinfo is determined alive if there is incoming traffic
|
||||||
if inTraffic {
|
if inTraffic {
|
||||||
decision := doNothing
|
decision := doNothing
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
"tunnelCheck", m{"state": "alive", "method": "passive"},
|
||||||
Debug("Tunnel status")
|
)
|
||||||
}
|
}
|
||||||
hostinfo.pendingDeletion.Store(false)
|
hostinfo.pendingDeletion.Store(false)
|
||||||
|
|
||||||
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if hostinfo.pendingDeletion.Load() {
|
if hostinfo.pendingDeletion.Load() {
|
||||||
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Info("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
"tunnelCheck", m{"state": "dead", "method": "active"},
|
||||||
Info("Tunnel status")
|
)
|
||||||
|
|
||||||
return deleteTunnel, hostinfo, nil
|
return deleteTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
||||||
if isInactive {
|
if isInactive {
|
||||||
// Tunnel is inactive, tear it down
|
// Tunnel is inactive, tear it down
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
|
||||||
WithField("inactiveDuration", inactiveFor).
|
"inactiveDuration", inactiveFor,
|
||||||
WithField("primary", mainHostInfo).
|
"primary", mainHostInfo,
|
||||||
Info("Dropping tunnel due to inactivity")
|
)
|
||||||
|
|
||||||
return closeTunnel, hostinfo, primary
|
return closeTunnel, hostinfo, primary
|
||||||
}
|
}
|
||||||
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
cm.sendPunch(hostinfo)
|
cm.sendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
"tunnelCheck", m{"state": "testing", "method": "active"},
|
||||||
Debug("Tunnel status")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
decision = sendTestPacket
|
decision = sendTestPacket
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
|
hostinfo.logger(cm.l).Debug("Hostinfo sadness")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
|
|||||||
return false //cert is still valid! yay!
|
return false //cert is still valid! yay!
|
||||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
"error", err,
|
||||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
"fingerprint", remoteCert.Fingerprint,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
} else if cm.intf.disconnectInvalid.Load() {
|
} else if cm.intf.disconnectInvalid.Load() {
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
"error", err,
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
"fingerprint", remoteCert.Fingerprint,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
} else {
|
} else {
|
||||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||||
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
curCrtVersion := curCrt.Version()
|
curCrtVersion := curCrt.Version()
|
||||||
myCrt := cs.getCertificate(curCrtVersion)
|
myCrt := cs.getCertificate(curCrtVersion)
|
||||||
if myCrt == nil {
|
if myCrt == nil {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("version", curCrtVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("reason", "local certificate removed").
|
"version", curCrtVersion,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate removed",
|
||||||
|
)
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("version", curCrtVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
"version", curCrtVersion,
|
||||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
"peerVersion", peerCrt.Certificate.Version(),
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate version lower than peer, attempting to correct",
|
||||||
|
)
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||||
})
|
})
|
||||||
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("reason", "local certificate is not current").
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate is not current",
|
||||||
|
)
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if curCrtVersion < cs.initiatingVersion {
|
if curCrtVersion < cs.initiatingVersion {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "current cert version < pki.initiatingVersion",
|
||||||
|
)
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/overlaytest"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
conf.Settings["tunnels"] = map[string]any{
|
conf.Settings["tunnels"] = map[string]any{
|
||||||
"drop_inactive": true,
|
"drop_inactive": true,
|
||||||
}
|
}
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
assert.True(t, nc.dropInactive.Load())
|
assert.True(t, nc.dropInactive.Load())
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
|
|
||||||
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
ifce.disconnectInvalid.Store(true)
|
ifce.disconnectInvalid.Store(true)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
@@ -27,7 +26,7 @@ type ConnectionState struct {
|
|||||||
writeLock sync.Mutex
|
writeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||||
var dhFunc noise.DHFunc
|
var dhFunc noise.DHFunc
|
||||||
switch crt.Curve() {
|
switch crt.Curve() {
|
||||||
case cert.Curve_CURVE25519:
|
case cert.Curve_CURVE25519:
|
||||||
|
|||||||
92
control.go
92
control.go
@@ -2,17 +2,33 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RunState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateUnknown RunState = iota
|
||||||
|
StateReady
|
||||||
|
StateStarted
|
||||||
|
StateStopping
|
||||||
|
StateStopped
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrAlreadyStarted = errors.New("nebula is already started")
|
||||||
|
var ErrAlreadyStopped = errors.New("nebula cannot be restarted")
|
||||||
|
var ErrUnknownState = errors.New("nebula state is invalid")
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -26,8 +42,11 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
|
stateLock sync.Mutex
|
||||||
|
state RunState
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
sshStart func()
|
sshStart func()
|
||||||
@@ -49,10 +68,31 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
// Start actually runs nebula, this is a nonblocking call.
|
||||||
func (c *Control) Start() {
|
// The returned function blocks until nebula has fully stopped and returns the
|
||||||
|
// first fatal reader error (if any). A nil error means nebula shut down
|
||||||
|
// gracefully; a non-nil error means a reader hit an unexpected failure that
|
||||||
|
// triggered the shutdown.
|
||||||
|
func (c *Control) Start() (func() error, error) {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
switch c.state {
|
||||||
|
case StateReady:
|
||||||
|
//yay!
|
||||||
|
case StateStopped, StateStopping:
|
||||||
|
return nil, ErrAlreadyStopped
|
||||||
|
case StateStarted:
|
||||||
|
return nil, ErrAlreadyStarted
|
||||||
|
default:
|
||||||
|
return nil, ErrUnknownState
|
||||||
|
}
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
c.f.activate()
|
err := c.f.activate()
|
||||||
|
if err != nil {
|
||||||
|
c.state = StateStopped
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -71,25 +111,51 @@ func (c *Control) Start() {
|
|||||||
c.lighthouseStart()
|
c.lighthouseStart()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.f.triggerShutdown = c.Stop
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.f.run()
|
out, err := c.f.run()
|
||||||
|
if err != nil {
|
||||||
|
c.state = StateStopped
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.state = StateStarted
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Control) State() RunState {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
return c.state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
if c.state != StateStarted {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
// We are stopping or stopped already
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = StateStopping
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
|
|
||||||
c.CloseAllTunnels(false)
|
c.CloseAllTunnels(false)
|
||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.Error("Close interface failed", "error", err)
|
||||||
}
|
}
|
||||||
c.l.Info("Goodbye")
|
c.stateLock.Lock()
|
||||||
|
c.state = StateStopped
|
||||||
|
c.stateLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
@@ -100,7 +166,7 @@ func (c *Control) ShutdownBlock() {
|
|||||||
|
|
||||||
rawSig := <-sigChan
|
rawSig := <-sigChan
|
||||||
sig := rawSig.String()
|
sig := rawSig.String()
|
||||||
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
|
c.l.Info("Caught signal, shutting down", "signal", sig)
|
||||||
c.Stop()
|
c.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -237,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
|||||||
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
c.f.closeTunnel(h)
|
c.f.closeTunnel(h)
|
||||||
|
|
||||||
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
|
c.l.Debug("Sending close tunnel message",
|
||||||
Debug("Sending close tunnel message")
|
"vpnAddrs", h.vpnAddrs,
|
||||||
|
"udpAddr", h.remote,
|
||||||
|
)
|
||||||
closed++
|
closed++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -79,10 +78,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
}, &Interface{})
|
}, &Interface{})
|
||||||
|
|
||||||
c := Control{
|
c := Control{
|
||||||
|
state: StateReady,
|
||||||
f: &Interface{
|
f: &Interface{
|
||||||
hostMap: hm,
|
hostMap: hm,
|
||||||
},
|
},
|
||||||
l: logrus.New(),
|
l: test.NewLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build e2e_testing
|
//go:build e2e_testing
|
||||||
// +build e2e_testing
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
22
dist/wireshark/nebula.lua
vendored
22
dist/wireshark/nebula.lua
vendored
@@ -84,30 +84,24 @@ end
|
|||||||
|
|
||||||
function nebula.prefs_changed()
|
function nebula.prefs_changed()
|
||||||
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
|
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
|
||||||
-- Nothing changed, bail
|
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Remove our old dissector
|
-- Remove all existing registrations
|
||||||
DissectorTable.get("udp.port"):remove_all(nebula)
|
DissectorTable.get("udp.port"):remove_all(nebula)
|
||||||
|
|
||||||
if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then
|
if nebula.prefs.all_ports then
|
||||||
default_settings.all_port = nebula.prefs.all_ports
|
-- Register on every port for hole punch capture
|
||||||
|
|
||||||
for i=0, 65535 do
|
for i=0, 65535 do
|
||||||
DissectorTable.get("udp.port"):add(i, nebula)
|
DissectorTable.get("udp.port"):add(i, nebula)
|
||||||
end
|
end
|
||||||
|
else
|
||||||
-- no need to establish again on specific ports
|
-- Register on the configured port only
|
||||||
return
|
DissectorTable.get("udp.port"):add(nebula.prefs.port, nebula)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
default_settings.all_ports = nebula.prefs.all_ports
|
||||||
if default_settings.all_ports ~= nebula.prefs.all_ports then
|
default_settings.port = nebula.prefs.port
|
||||||
-- Add our new port dissector
|
|
||||||
default_settings.port = nebula.prefs.port
|
|
||||||
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
||||||
|
|||||||
299
dns_server.go
299
dns_server.go
@@ -1,63 +1,249 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This whole thing should be rewritten to use context
|
type dnsServer struct {
|
||||||
|
|
||||||
var dnsR *dnsRecords
|
|
||||||
var dnsServer *dns.Server
|
|
||||||
var dnsAddr string
|
|
||||||
|
|
||||||
type dnsRecords struct {
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
ctx context.Context
|
||||||
dnsMap4 map[string]netip.Addr
|
dnsMap4 map[string]netip.Addr
|
||||||
dnsMap6 map[string]netip.Addr
|
dnsMap6 map[string]netip.Addr
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
myVpnAddrsTable *bart.Lite
|
myVpnAddrsTable *bart.Lite
|
||||||
|
|
||||||
|
mux *dns.ServeMux
|
||||||
|
|
||||||
|
// enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`.
|
||||||
|
// Start, Add, and reload consult it so callers don't need to know the
|
||||||
|
// gating rules. When it toggles off via reload, accumulated records are
|
||||||
|
// cleared so a later re-enable starts with a fresh map populated from
|
||||||
|
// new handshakes.
|
||||||
|
enabled atomic.Bool
|
||||||
|
|
||||||
|
serverMu sync.Mutex
|
||||||
|
server *dns.Server
|
||||||
|
// started is closed once `server` has finished binding (or after
|
||||||
|
// ListenAndServe returns on a bind failure). Stop waits on it before
|
||||||
|
// calling Shutdown to avoid the miekg/dns "server not started" race
|
||||||
|
// where a Shutdown that arrives before bind completes is silently
|
||||||
|
// ignored, leaving the listener running forever.
|
||||||
|
started chan struct{}
|
||||||
|
addr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
// newDnsServerFromConfig builds a dnsServer, applies the initial config, and
|
||||||
return &dnsRecords{
|
// registers a reload callback. The reload callback is registered before the
|
||||||
|
// initial config is applied, so a SIGHUP can later enable, fix, or disable
|
||||||
|
// DNS even if the initial application failed.
|
||||||
|
//
|
||||||
|
// The dnsServer internally gates on `lighthouse.serve_dns &&
|
||||||
|
// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally,
|
||||||
|
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
|
||||||
|
// watcher that tears the listener down on nebula shutdown. The returned
|
||||||
|
// pointer is always non-nil, even on error.
|
||||||
|
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||||
|
ds := &dnsServer{
|
||||||
l: l,
|
l: l,
|
||||||
|
ctx: ctx,
|
||||||
dnsMap4: make(map[string]netip.Addr),
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
dnsMap6: make(map[string]netip.Addr),
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||||
}
|
}
|
||||||
|
ds.mux = dns.NewServeMux()
|
||||||
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||||
|
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := ds.reload(c, false); err != nil {
|
||||||
|
ds.l.Error("Failed to reload DNS responder from config", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ds.reload(c, true); err != nil {
|
||||||
|
return ds, err
|
||||||
|
}
|
||||||
|
return ds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
// reload applies the latest config and reconciles the running state with it:
|
||||||
|
// - enabled toggled on -> spawn a runner
|
||||||
|
// - enabled toggled off -> stop the runner
|
||||||
|
// - listen address changed (while running) -> restart on the new address
|
||||||
|
// - everything else -> no-op
|
||||||
|
//
|
||||||
|
// On the initial call it only records configuration; Control.Start is what
|
||||||
|
// launches the first runner via dnsStart.
|
||||||
|
func (d *dnsServer) reload(c *config.C, initial bool) error {
|
||||||
|
wantsDns := c.GetBool("lighthouse.serve_dns", false)
|
||||||
|
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||||
|
enabled := wantsDns && amLighthouse
|
||||||
|
newAddr := getDnsServerAddr(c)
|
||||||
|
|
||||||
|
d.serverMu.Lock()
|
||||||
|
running := d.server
|
||||||
|
runningStarted := d.started
|
||||||
|
sameAddr := d.addr == newAddr
|
||||||
|
d.addr = newAddr
|
||||||
|
d.enabled.Store(enabled)
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
|
||||||
|
if initial {
|
||||||
|
if wantsDns && !amLighthouse {
|
||||||
|
d.l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !enabled {
|
||||||
|
if running != nil {
|
||||||
|
d.Stop()
|
||||||
|
}
|
||||||
|
// Drop any records that accumulated while enabled; a later re-enable
|
||||||
|
// will repopulate from fresh handshakes.
|
||||||
|
d.clearRecords()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if running == nil {
|
||||||
|
// Was disabled (or never started); bring it up now.
|
||||||
|
go d.Start()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if sameAddr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d.shutdownServer(running, runningStarted, "reload")
|
||||||
|
// Old Start goroutine has now exited; bring up a fresh listener on the
|
||||||
|
// new address.
|
||||||
|
go d.Start()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdownServer waits for the server to finish binding (so Shutdown actually
|
||||||
|
// stops it rather than no-oping) and then shuts it down.
|
||||||
|
func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) {
|
||||||
|
if srv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if started != nil {
|
||||||
|
<-started
|
||||||
|
}
|
||||||
|
if err := srv.Shutdown(); err != nil {
|
||||||
|
d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start binds and serves the DNS responder. Blocks until Stop is called or
|
||||||
|
// the listener errors. Safe to call when DNS is disabled (returns
|
||||||
|
// immediately). This is what Control.dnsStart points at.
|
||||||
|
//
|
||||||
|
// Must be invoked after the tun device is active so that lighthouse.dns.host
|
||||||
|
// may bind to a nebula IP.
|
||||||
|
func (d *dnsServer) Start() {
|
||||||
|
if !d.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
started := make(chan struct{})
|
||||||
|
d.serverMu.Lock()
|
||||||
|
if d.ctx.Err() != nil {
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr := d.addr
|
||||||
|
server := &dns.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Net: "udp",
|
||||||
|
Handler: d.mux,
|
||||||
|
NotifyStartedFunc: func() { close(started) },
|
||||||
|
}
|
||||||
|
d.server = server
|
||||||
|
d.started = started
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
|
||||||
|
// Per-invocation ctx watcher. Exits when Start does, so we don't leak a
|
||||||
|
// watcher per reload-driven restart.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
d.shutdownServer(server, started, "shutdown")
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
d.l.Info("Starting DNS responder", "dnsListener", addr)
|
||||||
|
err := server.ListenAndServe()
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
// If the listener never bound (bind error) NotifyStartedFunc never fires,
|
||||||
|
// so close started here to release any Stop caller waiting on it.
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
default:
|
||||||
|
close(started)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
d.l.Warn("Failed to run the DNS responder", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop shuts down the active server, if any. Idempotent.
|
||||||
|
func (d *dnsServer) Stop() {
|
||||||
|
d.serverMu.Lock()
|
||||||
|
srv := d.server
|
||||||
|
started := d.started
|
||||||
|
d.server = nil
|
||||||
|
d.started = nil
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
d.shutdownServer(srv, started, "stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query returns the address for the given name and query type. The second
|
||||||
|
// return value reports whether the name is known at all (in either A or AAAA),
|
||||||
|
// which lets callers distinguish NODATA from NXDOMAIN.
|
||||||
|
func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) {
|
||||||
data = strings.ToLower(data)
|
data = strings.ToLower(data)
|
||||||
d.RLock()
|
d.RLock()
|
||||||
defer d.RUnlock()
|
defer d.RUnlock()
|
||||||
|
addr4, haveV4 := d.dnsMap4[data]
|
||||||
|
addr6, haveV6 := d.dnsMap6[data]
|
||||||
|
nameExists := haveV4 || haveV6
|
||||||
switch q {
|
switch q {
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
if r, ok := d.dnsMap4[data]; ok {
|
if haveV4 {
|
||||||
return r
|
return addr4, nameExists
|
||||||
}
|
}
|
||||||
case dns.TypeAAAA:
|
case dns.TypeAAAA:
|
||||||
if r, ok := d.dnsMap6[data]; ok {
|
if haveV6 {
|
||||||
return r
|
return addr6, nameExists
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return netip.Addr{}
|
return netip.Addr{}, nameExists
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) QueryCert(data string) string {
|
func (d *dnsServer) QueryCert(data string) string {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -80,8 +266,19 @@ func (d *dnsRecords) QueryCert(data string) string {
|
|||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clearRecords drops all DNS records.
|
||||||
|
func (d *dnsServer) clearRecords() {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
clear(d.dnsMap4)
|
||||||
|
clear(d.dnsMap6)
|
||||||
|
}
|
||||||
|
|
||||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||||
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
func (d *dnsServer) Add(host string, addresses []netip.Addr) {
|
||||||
|
if !d.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
host = strings.ToLower(host)
|
host = strings.ToLower(host)
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
@@ -101,7 +298,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
|
||||||
a, _, _ := net.SplitHostPort(addr)
|
a, _, _ := net.SplitHostPort(addr)
|
||||||
b, err := netip.ParseAddr(a)
|
b, err := netip.ParseAddr(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,13 +313,24 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
|||||||
return d.myVpnAddrsTable.Contains(b)
|
return d.myVpnAddrsTable.Contains(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
|
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
|
||||||
|
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
|
||||||
|
// type must be answered with NOERROR and an empty answer section (NODATA),
|
||||||
|
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
|
||||||
|
// exist at all.
|
||||||
|
anyNameExists := false
|
||||||
for _, q := range m.Question {
|
for _, q := range m.Question {
|
||||||
switch q.Qtype {
|
switch q.Qtype {
|
||||||
case dns.TypeA, dns.TypeAAAA:
|
case dns.TypeA, dns.TypeAAAA:
|
||||||
qType := dns.TypeToString[q.Qtype]
|
qType := dns.TypeToString[q.Qtype]
|
||||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
if debugEnabled {
|
||||||
ip := d.Query(q.Qtype, q.Name)
|
d.l.Debug("DNS query", "type", qType, "name", q.Name)
|
||||||
|
}
|
||||||
|
ip, nameExists := d.Query(q.Qtype, q.Name)
|
||||||
|
if nameExists {
|
||||||
|
anyNameExists = true
|
||||||
|
}
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -134,7 +342,9 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
d.l.Debugf("Query for TXT %s", q.Name)
|
if debugEnabled {
|
||||||
|
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
|
||||||
|
}
|
||||||
ip := d.QueryCert(q.Name)
|
ip := d.QueryCert(q.Name)
|
||||||
if ip != "" {
|
if ip != "" {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||||
@@ -145,12 +355,12 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.Answer) == 0 {
|
if len(m.Answer) == 0 && !anyNameExists {
|
||||||
m.Rcode = dns.RcodeNameError
|
m.Rcode = dns.RcodeNameError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.Compress = false
|
m.Compress = false
|
||||||
@@ -163,21 +373,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
|
||||||
dnsR = newDnsRecords(l, cs, hostMap)
|
|
||||||
|
|
||||||
// attach request handler func
|
|
||||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
reloadDns(l, c)
|
|
||||||
})
|
|
||||||
|
|
||||||
return func() {
|
|
||||||
startDns(l, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDnsServerAddr(c *config.C) string {
|
func getDnsServerAddr(c *config.C) string {
|
||||||
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
|
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
|
||||||
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
|
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
|
||||||
@@ -186,25 +381,3 @@ func getDnsServerAddr(c *config.C) string {
|
|||||||
}
|
}
|
||||||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func startDns(l *logrus.Logger, c *config.C) {
|
|
||||||
dnsAddr = getDnsServerAddr(c)
|
|
||||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
|
||||||
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
|
||||||
err := dnsServer.ListenAndServe()
|
|
||||||
defer dnsServer.Shutdown()
|
|
||||||
if err != nil {
|
|
||||||
l.Errorf("Failed to start server: %s\n ", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func reloadDns(l *logrus.Logger, c *config.C) {
|
|
||||||
if dnsAddr == getDnsServerAddr(c) {
|
|
||||||
l.Debug("No DNS server config change detected")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Debug("Restarting DNS server")
|
|
||||||
dnsServer.Shutdown()
|
|
||||||
go startDns(l, c)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,19 +1,43 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type stubDNSWriter struct{}
|
||||||
|
|
||||||
|
func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||||
|
func (stubDNSWriter) RemoteAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353}
|
||||||
|
}
|
||||||
|
func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||||
|
func (stubDNSWriter) Close() error { return nil }
|
||||||
|
func (stubDNSWriter) TsigStatus() error { return nil }
|
||||||
|
func (stubDNSWriter) TsigTimersOnly(bool) {}
|
||||||
|
func (stubDNSWriter) Hijack() {}
|
||||||
|
|
||||||
func TestParsequery(t *testing.T) {
|
func TestParsequery(t *testing.T) {
|
||||||
l := logrus.New()
|
l := slog.New(slog.DiscardHandler)
|
||||||
hostMap := &HostMap{}
|
hostMap := &HostMap{}
|
||||||
ds := newDnsRecords(l, &CertState{}, hostMap)
|
ds := &dnsServer{
|
||||||
|
l: l,
|
||||||
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
|
hostMap: hostMap,
|
||||||
|
}
|
||||||
|
ds.enabled.Store(true)
|
||||||
addrs := []netip.Addr{
|
addrs := []netip.Addr{
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
netip.MustParseAddr("1.2.3.5"),
|
netip.MustParseAddr("1.2.3.5"),
|
||||||
@@ -21,18 +45,56 @@ func TestParsequery(t *testing.T) {
|
|||||||
netip.MustParseAddr("fd01::25"),
|
netip.MustParseAddr("fd01::25"),
|
||||||
}
|
}
|
||||||
ds.Add("test.com.com", addrs)
|
ds.Add("test.com.com", addrs)
|
||||||
|
ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")})
|
||||||
|
ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")})
|
||||||
|
|
||||||
m := &dns.Msg{}
|
m := &dns.Msg{}
|
||||||
m.SetQuestion("test.com.com", dns.TypeA)
|
m.SetQuestion("test.com.com", dns.TypeA)
|
||||||
ds.parseQuery(m, nil)
|
ds.parseQuery(m, nil)
|
||||||
assert.NotNil(t, m.Answer)
|
assert.NotNil(t, m.Answer)
|
||||||
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
m = &dns.Msg{}
|
m = &dns.Msg{}
|
||||||
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
||||||
ds.parseQuery(m, nil)
|
ds.parseQuery(m, nil)
|
||||||
assert.NotNil(t, m.Answer)
|
assert.NotNil(t, m.Answer)
|
||||||
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
|
// A known name with no record of the requested type should return NODATA
|
||||||
|
// (NOERROR with empty answer), not NXDOMAIN.
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.SetQuestion("v4only.com.com", dns.TypeAAAA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.SetQuestion("v6only.com.com", dns.TypeA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
||||||
|
|
||||||
|
// An unknown name should still return NXDOMAIN.
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.SetQuestion("unknown.com.com", dns.TypeA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
||||||
|
|
||||||
|
// short lookups should not fail
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
|
||||||
|
ds.parseQuery(m, stubDNSWriter{})
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
||||||
|
|
||||||
|
m = &dns.Msg{}
|
||||||
|
m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
|
||||||
|
ds.parseQuery(m, stubDNSWriter{})
|
||||||
|
assert.Empty(t, m.Answer)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getDnsServerAddr(t *testing.T) {
|
func Test_getDnsServerAddr(t *testing.T) {
|
||||||
@@ -71,3 +133,208 @@ func Test_getDnsServerAddr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
||||||
|
t.Helper()
|
||||||
|
sl := slog.New(slog.DiscardHandler)
|
||||||
|
ds := &dnsServer{
|
||||||
|
l: sl,
|
||||||
|
ctx: context.Background(),
|
||||||
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
|
hostMap: &HostMap{},
|
||||||
|
}
|
||||||
|
ds.mux = dns.NewServeMux()
|
||||||
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||||
|
return ds, config.NewC(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
|
||||||
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
|
"am_lighthouse": amLighthouse,
|
||||||
|
"serve_dns": serveDns,
|
||||||
|
"dns": map[string]any{
|
||||||
|
"host": host,
|
||||||
|
"port": port,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_disabled(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, false)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_enabled(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
assert.True(t, ds.enabled.Load())
|
||||||
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
||||||
|
// initial never starts a runner; that's Control.Start's job
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", false, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
// Wants DNS but isn't a lighthouse: gated off, no runner.
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
// No server running yet, no addr change. Reload should not spawn anything.
|
||||||
|
require.NoError(t, ds.reload(c, false))
|
||||||
|
assert.True(t, ds.enabled.Load())
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_StartStop_lifecycle(t *testing.T) {
|
||||||
|
// Bind to a real (random) UDP port so we exercise the actual
|
||||||
|
// ListenAndServe + Shutdown plumbing including the started-chan race fix.
|
||||||
|
port := freeUDPPort(t)
|
||||||
|
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitFor(t, func() bool {
|
||||||
|
ds.serverMu.Lock()
|
||||||
|
started := ds.started
|
||||||
|
ds.serverMu.Unlock()
|
||||||
|
if started == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ds.Stop()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Start did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
|
||||||
|
// Stop called immediately after Start should not deadlock even if bind
|
||||||
|
// hasn't completed yet. This exercises the started-chan close-on-bind-fail
|
||||||
|
// path: by binding to an obviously bad port (privileged) we get a fast
|
||||||
|
// bind error before NotifyStartedFunc fires.
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
// Use a port that should fail to bind (negative would be invalid, use a
|
||||||
|
// host that won't resolve to ensure listenUDP fails quickly).
|
||||||
|
setDnsConfig(c, "256.256.256.256", "53", true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give Start a moment to attempt the bind and fail.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Bind failed and Start returned; Stop should be a no-op.
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Start did not return after a bad bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
stopped := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Stop()
|
||||||
|
close(stopped)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-stopped:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Stop hung after a failed bind")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
|
||||||
|
port := freeUDPPort(t)
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
startReturned := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(startReturned)
|
||||||
|
}()
|
||||||
|
waitForBind(t, ds)
|
||||||
|
|
||||||
|
// Toggle serve_dns off; reload should shut the running server down.
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, false)
|
||||||
|
require.NoError(t, ds.reload(c, false))
|
||||||
|
select {
|
||||||
|
case <-startReturned:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Start did not return after reload disabled DNS")
|
||||||
|
}
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func freeUDPPort(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
port := conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
require.NoError(t, conn.Close())
|
||||||
|
return strconv.Itoa(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForBind(t *testing.T, ds *dnsServer) {
|
||||||
|
t.Helper()
|
||||||
|
waitFor(t, func() bool {
|
||||||
|
ds.serverMu.Lock()
|
||||||
|
started := ds.started
|
||||||
|
ds.serverMu.Unlock()
|
||||||
|
if started == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitFor(t *testing.T, cond func() bool) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if cond() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatal("timed out waiting for condition")
|
||||||
|
}
|
||||||
|
|||||||
565
e2e/handshake_manager_test.go
Normal file
565
e2e/handshake_manager_test.go
Normal file
@@ -0,0 +1,565 @@
|
|||||||
|
//go:build e2e_testing
|
||||||
|
// +build e2e_testing
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeHandshakePacket creates a handshake packet with the given parameters.
|
||||||
|
func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, remoteIndex uint32, counter uint64) *udp.Packet {
|
||||||
|
data := make([]byte, 200)
|
||||||
|
header.Encode(data, header.Version, header.Handshake, subtype, remoteIndex, counter)
|
||||||
|
for i := header.Len; i < len(data); i++ {
|
||||||
|
data[i] = byte(i)
|
||||||
|
}
|
||||||
|
return &udp.Packet{To: to, From: from, Data: data}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeRetransmitDuplicate(t *testing.T) {
|
||||||
|
// Verify the responder correctly handles receiving the same msg1 multiple times
|
||||||
|
// (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen
|
||||||
|
// and the cached response is resent.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from me to them")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
|
||||||
|
|
||||||
|
t.Log("Grab my msg1")
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Inject msg1 into them, first time")
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
_ = theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Inject the SAME msg1 again, tests ErrAlreadySeen path")
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
resp2 := theirControl.GetFromUDP(true)
|
||||||
|
assert.NotNil(t, resp2, "should get cached response on duplicate msg1")
|
||||||
|
|
||||||
|
t.Log("Complete handshake with cached response")
|
||||||
|
myControl.InjectUDPPacket(resp2)
|
||||||
|
myControl.WaitForType(1, 0, theirControl)
|
||||||
|
|
||||||
|
t.Log("Drain cached packet and verify tunnel works")
|
||||||
|
cachedPacket := theirControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Verify only one tunnel exists on each side")
|
||||||
|
assert.Len(t, myControl.ListHostmapHosts(false), 1)
|
||||||
|
assert.Len(t, theirControl.ListHostmapHosts(false), 1)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeTruncatedPacketRecovery(t *testing.T) {
|
||||||
|
// Verify that a truncated handshake packet is ignored and the real
|
||||||
|
// packet can still complete the handshake.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
|
||||||
|
|
||||||
|
t.Log("Get msg1 and deliver to responder")
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
|
||||||
|
t.Log("Get the real response")
|
||||||
|
realResp := theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Truncate the response and inject, should be ignored")
|
||||||
|
truncResp := realResp.Copy()
|
||||||
|
truncResp.Data = truncResp.Data[:header.Len]
|
||||||
|
myControl.InjectUDPPacket(truncResp)
|
||||||
|
|
||||||
|
t.Log("Verify pending handshake survived the truncated packet")
|
||||||
|
assert.NotEmpty(t, myControl.ListHostmapHosts(true), "pending handshake should still exist")
|
||||||
|
|
||||||
|
t.Log("Inject real response, should complete handshake")
|
||||||
|
myControl.InjectUDPPacket(realResp)
|
||||||
|
myControl.WaitForType(1, 0, theirControl)
|
||||||
|
|
||||||
|
t.Log("Drain and verify tunnel")
|
||||||
|
cachedPacket := theirControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi"), cachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeOrphanedMsg2Dropped(t *testing.T) {
|
||||||
|
// A msg2 arriving with no matching pending index should be silently dropped
|
||||||
|
// with no response sent and no state changes.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Complete a normal handshake")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
|
||||||
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Record hostmap state")
|
||||||
|
myIndexes := len(myControl.ListHostmapIndexes(false))
|
||||||
|
|
||||||
|
t.Log("Inject a fake msg2 with unknown RemoteIndex")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0xDEADBEEF, 2))
|
||||||
|
|
||||||
|
t.Log("Verify no new indexes created")
|
||||||
|
assert.Equal(t, myIndexes, len(myControl.ListHostmapIndexes(false)))
|
||||||
|
|
||||||
|
t.Log("Verify no UDP response was sent")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false), "should not send a response to orphaned msg2")
|
||||||
|
|
||||||
|
t.Log("Verify existing tunnel still works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeUnknownMessageCounter(t *testing.T) {
|
||||||
|
// A handshake packet with an unexpected message counter should be silently
|
||||||
|
// dropped with no side effects and no UDP response.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Inject handshake with MessageCounter=3")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 3))
|
||||||
|
|
||||||
|
t.Log("Inject handshake with MessageCounter=99")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 99))
|
||||||
|
|
||||||
|
t.Log("Verify no tunnels or pending handshakes")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
t.Log("Verify no UDP response was sent")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeUnknownSubtype(t *testing.T) {
|
||||||
|
// A handshake packet with an unknown subtype should be silently dropped.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Inject handshake with unknown subtype 99")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.MessageSubType(99), 0, 1))
|
||||||
|
|
||||||
|
t.Log("Verify no tunnels or pending handshakes")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
t.Log("Verify no UDP response was sent")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeLateResponse(t *testing.T) {
|
||||||
|
// After a handshake times out, a late response should be silently ignored
|
||||||
|
// with no new tunnels created.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
|
||||||
|
"handshakes": m{
|
||||||
|
"try_interval": "200ms",
|
||||||
|
"retries": 2,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from me")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
|
||||||
|
|
||||||
|
t.Log("Grab msg1 but don't deliver")
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Wait for handshake to time out")
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
myControl.GetFromUDP(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Confirm no pending handshakes remain")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
t.Log("Deliver old msg1 to them, they create a tunnel")
|
||||||
|
theirControl.InjectUDPPacket(msg1)
|
||||||
|
resp := theirControl.GetFromUDP(true)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
|
||||||
|
t.Log("Inject late response into me, should be ignored")
|
||||||
|
myControl.InjectUDPPacket(resp)
|
||||||
|
|
||||||
|
t.Log("No tunnel should exist on my side")
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeSelfConnectionRejected(t *testing.T) {
|
||||||
|
// Verify that a node rejects a handshake containing its own VPN IP in the
|
||||||
|
// peer cert. We do this by sending the initiator's own msg1 back to itself.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
|
||||||
|
// Need a lighthouse entry to trigger a handshake
|
||||||
|
myControl.InjectLightHouseAddr(netip.MustParseAddr("10.128.0.2"), netip.MustParseAddrPort("10.0.0.2:4242"))
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from me")
|
||||||
|
myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
|
||||||
|
msg1 := myControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Drain any handshake retransmits before injecting")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
for myControl.GetFromUDP(false) != nil {
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Feed my own msg1 back to me as if it came from someone else")
|
||||||
|
selfMsg := msg1.Copy()
|
||||||
|
selfMsg.From = netip.MustParseAddrPort("10.0.0.99:4242")
|
||||||
|
selfMsg.To = myUdpAddr
|
||||||
|
myControl.InjectUDPPacket(selfMsg)
|
||||||
|
|
||||||
|
t.Log("Verify no response was sent (self-connection rejected)")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// Drain any further retransmits from the original handshake, then check
|
||||||
|
// that none of them are a handshake response (MessageCounter=2)
|
||||||
|
h := &header.H{}
|
||||||
|
for {
|
||||||
|
p := myControl.GetFromUDP(false)
|
||||||
|
if p == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_ = h.Parse(p.Data)
|
||||||
|
assert.NotEqual(t, uint64(2), h.MessageCounter,
|
||||||
|
"should not send a stage 2 response to self-connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Verify no tunnel to myself was created")
|
||||||
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeMessageCounter0Dropped(t *testing.T) {
|
||||||
|
// MessageCounter=0 is not a valid handshake message and should be dropped.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, _, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
_, _, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
|
||||||
|
t.Log("Inject handshake with MessageCounter=0")
|
||||||
|
myControl.InjectUDPPacket(makeHandshakePacket(theirUdpAddr, myUdpAddr, header.HandshakeIXPSK0, 0, 0))
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false))
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeRemoteAllowList(t *testing.T) {
|
||||||
|
// Verify that a handshake from a blocked underlay IP is dropped with no
|
||||||
|
// response and no state changes. Then verify the same packet from an
|
||||||
|
// allowed IP succeeds.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{
|
||||||
|
"lighthouse": m{
|
||||||
|
"remote_allow_list": m{
|
||||||
|
"10.0.0.0/8": true,
|
||||||
|
"0.0.0.0/0": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake from them")
|
||||||
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi"))
|
||||||
|
msg1 := theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Rewrite the source to a blocked IP and inject")
|
||||||
|
blockedMsg := msg1.Copy()
|
||||||
|
blockedMsg.From = netip.MustParseAddrPort("192.168.1.1:4242")
|
||||||
|
myControl.InjectUDPPacket(blockedMsg)
|
||||||
|
|
||||||
|
t.Log("Verify no tunnel, no pending, no response from blocked source")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(false))
|
||||||
|
assert.Empty(t, myControl.ListHostmapHosts(true))
|
||||||
|
assert.Nil(t, myControl.GetFromUDP(false), "should not respond to blocked source")
|
||||||
|
|
||||||
|
t.Log("Now inject the real packet from the allowed source")
|
||||||
|
myControl.InjectUDPPacket(msg1)
|
||||||
|
|
||||||
|
t.Log("Verify handshake completes from allowed source")
|
||||||
|
resp := myControl.GetFromUDP(true)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
theirControl.InjectUDPPacket(resp)
|
||||||
|
theirControl.WaitForType(1, 0, myControl)
|
||||||
|
|
||||||
|
t.Log("Drain cached packet and verify tunnel works")
|
||||||
|
cachedPacket := myControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi"), cachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) {
|
||||||
|
// When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel
|
||||||
|
// remains functional and hostmap index count is stable.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
t.Log("Complete a normal handshake via the router")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))
|
||||||
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Record hostmap state")
|
||||||
|
theirIndexes := len(theirControl.ListHostmapIndexes(false))
|
||||||
|
hi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, hi)
|
||||||
|
originalRemote := hi.CurrentRemote
|
||||||
|
|
||||||
|
t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam"))
|
||||||
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
|
||||||
|
t.Log("Verify tunnel still works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Verify remote is still valid and index count is stable")
|
||||||
|
hi2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, hi2)
|
||||||
|
assert.Equal(t, originalRemote, hi2.CurrentRemote)
|
||||||
|
assert.Equal(t, theirIndexes, len(theirControl.ListHostmapIndexes(false)),
|
||||||
|
"no extra indexes should be created from ErrAlreadySeen")
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeWrongResponderPacketStore(t *testing.T) {
|
||||||
|
// Verify that when the wrong host responds, the cached packets are
|
||||||
|
// transferred to the new handshake, the evil tunnel is closed, evil's
|
||||||
|
// address is blocked, and the correct tunnel is eventually established.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
|
||||||
|
evilControl, evilVpnIpNet, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl, evilControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
evilControl.Start()
|
||||||
|
|
||||||
|
t.Log("Send multiple packets to them (cached during handshake)")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1"))
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2"))
|
||||||
|
|
||||||
|
t.Log("Route until evil tunnel is closed")
|
||||||
|
h := &header.H{}
|
||||||
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
|
if err := h.Parse(p.Data); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if h.Type == header.CloseTunnel && p.To == evilUdpAddr {
|
||||||
|
return router.RouteAndExit
|
||||||
|
}
|
||||||
|
return router.KeepRouting
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Log("Verify evil's address is blocked in the new pending handshake")
|
||||||
|
pendingHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
|
||||||
|
if pendingHI != nil {
|
||||||
|
assert.NotContains(t, pendingHI.RemoteAddrs, evilUdpAddr,
|
||||||
|
"evil's address should be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Inject correct lighthouse addr for them")
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
t.Log("Route until cached packets arrive at the real them")
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assert.NotNil(t, p, "a cached packet should be delivered to the correct host")
|
||||||
|
|
||||||
|
t.Log("Verify the correct host has a tunnel")
|
||||||
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||||
|
|
||||||
|
t.Log("Verify no hostinfo artifacts from evil remain")
|
||||||
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), true),
|
||||||
|
"no pending hostinfo for evil")
|
||||||
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIpNet[0].Addr(), false),
|
||||||
|
"no main hostinfo for evil")
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
evilControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshakeRelayComplete(t *testing.T) {
|
||||||
|
// Verify that a relay handshake completes correctly and relay state is
|
||||||
|
// properly maintained on all three nodes.
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger handshake via relay")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay"))
|
||||||
|
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("Verify bidirectional tunnel via relay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
t.Log("Verify relay state on my side shows relay-to-me")
|
||||||
|
myHI := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, myHI)
|
||||||
|
assert.NotEmpty(t, myHI.CurrentRelaysToMe, "should have relay-to-me for them")
|
||||||
|
|
||||||
|
t.Log("Verify relay state on their side shows relay-to-me")
|
||||||
|
theirHI := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, theirHI)
|
||||||
|
assert.NotEmpty(t, theirHI.CurrentRelaysToMe, "should have relay-to-me for me")
|
||||||
|
|
||||||
|
t.Log("Verify relay node shows through-me relays")
|
||||||
|
relayHI := relayControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
assert.NotNil(t, relayHI)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
relayControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: Relay V1 cert + IPv6 rejection is not tested here because
|
||||||
|
// InjectTunUDPPacket from a V4 node to a V6 address panics in the test
|
||||||
|
// framework. The check is in handshake_manager.go handleOutbound relay
|
||||||
|
// logic (lines ~304-313): if the relay host has a V1 cert and either
|
||||||
|
// address is IPv6, the relay is skipped.
|
||||||
|
|
||||||
|
// NOTE: Relay reestablishment (Disestablished state transition) is covered
|
||||||
|
// by the existing TestReestablishRelays in handshakes_test.go.
|
||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
@@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
@@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Get a tunnel between me and relay")
|
r.Log("Get a tunnel between me and relay")
|
||||||
l.Info("Get a tunnel between me and relay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Get a tunnel between them and relay")
|
r.Log("Get a tunnel between them and relay")
|
||||||
l.Info("Get a tunnel between them and relay")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||||
l.Info("Trigger a handshake from both them and me via relay to them and me")
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
||||||
|
|
||||||
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
||||||
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
||||||
|
|
||||||
r.Log("Wait for a packet from them to me")
|
r.Log("Wait for a packet from them to me; myControl")
|
||||||
l.Info("Wait for a packet from them to me; myControl")
|
|
||||||
r.RouteForAllUntilTxTun(myControl)
|
r.RouteForAllUntilTxTun(myControl)
|
||||||
l.Info("Wait for a packet from them to me; theirControl")
|
r.Log("Wait for a packet from them to me; theirControl")
|
||||||
r.RouteForAllUntilTxTun(theirControl)
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
t.Log("Wait until we remove extra tunnels")
|
t.Log("Wait until we remove extra tunnels")
|
||||||
l.Info("Wait until we remove extra tunnels")
|
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||||
l.WithFields(
|
len(myControl.GetHostmap().Indexes),
|
||||||
logrus.Fields{
|
len(theirControl.GetHostmap().Indexes),
|
||||||
"myControl": len(myControl.GetHostmap().Indexes),
|
len(relayControl.GetHostmap().Indexes),
|
||||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
)
|
||||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
|
||||||
}).Info("Waiting for hostinfos to be removed...")
|
|
||||||
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||||
retries := 60
|
retries := 60
|
||||||
for hostInfos > 6 && retries > 0 {
|
for hostInfos > 6 && retries > 0 {
|
||||||
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||||
l.WithFields(
|
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||||
logrus.Fields{
|
len(myControl.GetHostmap().Indexes),
|
||||||
"myControl": len(myControl.GetHostmap().Indexes),
|
len(theirControl.GetHostmap().Indexes),
|
||||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
len(relayControl.GetHostmap().Indexes),
|
||||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
)
|
||||||
}).Info("Waiting for hostinfos to be removed...")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
t.Log("Connection manager hasn't ticked yet")
|
t.Log("Connection manager hasn't ticked yet")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
@@ -1369,6 +1358,81 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLighthouseUpdateOnReload(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
// Create the lighthouse
|
||||||
|
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh", "10.128.0.1/24", m{"lighthouse": m{"am_lighthouse": true}})
|
||||||
|
|
||||||
|
// Create a client with NO lighthouse configured and a long update interval.
|
||||||
|
// The initial SendUpdate at startup will be a no-op since no lighthouses are known.
|
||||||
|
myControl, myVpnIpNet, _, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.2/24", m{
|
||||||
|
"lighthouse": m{
|
||||||
|
"interval": 600,
|
||||||
|
"local_allow_list": m{
|
||||||
|
"10.0.0.0/24": true,
|
||||||
|
"::/0": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
r := router.NewR(t, lhControl, myControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
lhControl.Start()
|
||||||
|
myControl.Start()
|
||||||
|
|
||||||
|
// Drain any startup packets (there should be none meaningful)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
// Verify lighthouse has no knowledge of the client
|
||||||
|
assert.Nil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
|
||||||
|
|
||||||
|
// Build a new config that adds the lighthouse
|
||||||
|
newSettings := make(m)
|
||||||
|
for k, v := range myConfig.Settings {
|
||||||
|
newSettings[k] = v
|
||||||
|
}
|
||||||
|
newSettings["static_host_map"] = m{
|
||||||
|
lhVpnIpNet[0].Addr().String(): []any{lhUdpAddr.String()},
|
||||||
|
}
|
||||||
|
newSettings["lighthouse"] = m{
|
||||||
|
"hosts": []any{lhVpnIpNet[0].Addr().String()},
|
||||||
|
"interval": 600,
|
||||||
|
"local_allow_list": m{
|
||||||
|
"10.0.0.0/24": true,
|
||||||
|
"::/0": false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg, err := yaml.Marshal(newSettings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Reload the config. The lighthouse.hosts change triggers TriggerUpdate,
|
||||||
|
// which wakes the update worker. It calls SendUpdate, initiating a
|
||||||
|
// handshake to the new lighthouse and caching the HostUpdateNotification.
|
||||||
|
require.NoError(t, myConfig.ReloadConfigString(string(newCfg)))
|
||||||
|
|
||||||
|
// Route until the lighthouse receives the HostUpdateNotification.
|
||||||
|
// This covers: handshake stage 1, stage 2, then the cached update.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.RouteForAllUntilAfterMsgTypeTo(lhControl, header.LightHouse, 0)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for lighthouse update after config reload")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify lighthouse now has the client's addresses
|
||||||
|
assert.NotNil(t, lhControl.QueryLighthouse(myVpnIpNet[0].Addr()))
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", lhControl, myControl)
|
||||||
|
lhControl.Stop()
|
||||||
|
myControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
||||||
unsafePrefix := "192.168.6.0/24"
|
unsafePrefix := "192.168.6.0/24"
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,15 +11,18 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
|||||||
"port": udpAddr.Port(),
|
"port": udpAddr.Port(),
|
||||||
},
|
},
|
||||||
"logging": m{
|
"logging": m{
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
|
"level": testLogLevelName(),
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
},
|
||||||
"timers": m{
|
"timers": m{
|
||||||
"pending_deletion_interval": 2,
|
"pending_deletion_interval": 2,
|
||||||
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
|
|||||||
"port": udpAddr.Port(),
|
"port": udpAddr.Port(),
|
||||||
},
|
},
|
||||||
"logging": m{
|
"logging": m{
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
"level": testLogLevelName(),
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
},
|
||||||
"timers": m{
|
"timers": m{
|
||||||
"pending_deletion_interval": 2,
|
"pending_deletion_interval": 2,
|
||||||
@@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestLogger() *logrus.Logger {
|
func NewTestLogger() *slog.Logger {
|
||||||
l := logrus.New()
|
|
||||||
|
|
||||||
v := os.Getenv("TEST_LOGS")
|
v := os.Getenv("TEST_LOGS")
|
||||||
if v == "" {
|
if v == "" {
|
||||||
l.SetOutput(io.Discard)
|
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
l.SetLevel(logrus.PanicLevel)
|
|
||||||
return l
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
level := slog.LevelInfo
|
||||||
switch v {
|
switch v {
|
||||||
case "2":
|
case "2":
|
||||||
l.SetLevel(logrus.DebugLevel)
|
level = slog.LevelDebug
|
||||||
case "3":
|
case "3":
|
||||||
l.SetLevel(logrus.TraceLevel)
|
level = logging.LevelTrace
|
||||||
default:
|
|
||||||
l.SetLevel(logrus.InfoLevel)
|
|
||||||
}
|
}
|
||||||
|
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
|
||||||
return l
|
}
|
||||||
|
|
||||||
|
// testLogLevelName returns the level name string accepted by logging.ApplyConfig
|
||||||
|
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
|
||||||
|
func testLogLevelName() string {
|
||||||
|
switch os.Getenv("TEST_LOGS") {
|
||||||
|
case "2":
|
||||||
|
return "debug"
|
||||||
|
case "3":
|
||||||
|
return "trace"
|
||||||
|
case "":
|
||||||
|
return "info"
|
||||||
|
}
|
||||||
|
return "info"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
@@ -365,3 +367,106 @@ func TestCrossStackRelaysWork(t *testing.T) {
|
|||||||
//theirControl.Stop()
|
//theirControl.Stop()
|
||||||
//relayControl.Stop()
|
//relayControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseTunnelAuthenticated(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
r.Log("Close the tunnel")
|
||||||
|
myControl.CloseTunnel(theirVpnIpNet[0].Addr(), false)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
myIndexes := len(myControl.GetHostmap().Indexes)
|
||||||
|
theirIndexes := len(theirControl.GetHostmap().Indexes)
|
||||||
|
if myIndexes == 0 && theirIndexes == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
|
||||||
|
if since > time.Second*6 {
|
||||||
|
t.Fatal("Tunnel should have been declared inactive after 2 seconds and before 6 seconds")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
//r.FlushAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("Happy path success, tunnels were dropped within %v", time.Since(waitStart))
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
r.Log("Assert another tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
hi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if hi == nil {
|
||||||
|
t.Fatal("There is no hostinfo for this tunnel")
|
||||||
|
}
|
||||||
|
myHi := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
if myHi == nil {
|
||||||
|
t.Fatal("There is no hostinfo for my tunnel")
|
||||||
|
}
|
||||||
|
r.Log("It does")
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
hdr := header.H{
|
||||||
|
Version: 1,
|
||||||
|
Type: header.CloseTunnel,
|
||||||
|
Subtype: 0,
|
||||||
|
Reserved: 0,
|
||||||
|
RemoteIndex: hi.RemoteIndex,
|
||||||
|
MessageCounter: 5,
|
||||||
|
}
|
||||||
|
out, err := hdr.Encode(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := &udp.Packet{
|
||||||
|
To: hi.CurrentRemote,
|
||||||
|
From: myHi.CurrentRemote,
|
||||||
|
Data: out,
|
||||||
|
}
|
||||||
|
r.InjectUDPPacket(myControl, theirControl, pkt)
|
||||||
|
r.Log("Injected bogus close tunnel. Let's see!")
|
||||||
|
waitStart = time.Now()
|
||||||
|
for {
|
||||||
|
myIndexes := len(myControl.GetHostmap().Indexes)
|
||||||
|
theirIndexes := len(theirControl.GetHostmap().Indexes)
|
||||||
|
if myIndexes == 0 {
|
||||||
|
t.Fatal("myIndexes should not be 0")
|
||||||
|
}
|
||||||
|
if theirIndexes == 0 {
|
||||||
|
t.Fatal("theirIndexes should not be 0, they should have rejected this bogus packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
|
||||||
|
if since > time.Second*4 {
|
||||||
|
t.Log("The tunnel would have been gone by now")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|||||||
@@ -204,6 +204,12 @@ punchy:
|
|||||||
# Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access.
|
# Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access.
|
||||||
#trusted_cas:
|
#trusted_cas:
|
||||||
#- "ssh public key string"
|
#- "ssh public key string"
|
||||||
|
# sandbox_dir restricts file paths for profiling commands (start-cpu-profile, save-heap-profile,
|
||||||
|
# save-mutex-profile) to the specified directory. Relative paths will be resolved within this directory,
|
||||||
|
# and absolute paths outside of it will be rejected. Default is $TMP/nebula-debug.
|
||||||
|
# The directory is NOT automatically created.
|
||||||
|
# Overriding this to "" is the same as "/" and will allow overwriting any path on the host.
|
||||||
|
#sandbox_dir: /var/tmp/nebula-debug
|
||||||
|
|
||||||
# EXPERIMENTAL: relay support for networks that can't establish direct connections.
|
# EXPERIMENTAL: relay support for networks that can't establish direct connections.
|
||||||
relay:
|
relay:
|
||||||
@@ -286,24 +292,21 @@ tun:
|
|||||||
|
|
||||||
# Configure logging level
|
# Configure logging level
|
||||||
logging:
|
logging:
|
||||||
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
|
# trace, debug, info, warn, or error. Default is info and is reloadable.
|
||||||
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
|
# fatal and panic are accepted for backwards compatibility and map to error.
|
||||||
# scenarios. Debug logging is also CPU intensive and will decrease performance overall.
|
#NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
|
||||||
# Only enable debug logging while actively investigating an issue.
|
# scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
|
||||||
|
# Only enable debug or trace logging while actively investigating an issue.
|
||||||
level: info
|
level: info
|
||||||
# json or text formats currently available. Default is text
|
# json or text formats currently available. Default is text.
|
||||||
format: text
|
format: text
|
||||||
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
|
# Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
|
||||||
#disable_timestamp: true
|
#disable_timestamp: true
|
||||||
# timestamp format is specified in Go time format, see:
|
# Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
|
||||||
# https://golang.org/pkg/time/#pkg-constants
|
|
||||||
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
|
||||||
# default when `format: text`:
|
|
||||||
# when TTY attached: seconds since beginning of execution
|
|
||||||
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
|
||||||
# As an example, to log as RFC3339 with millisecond precision, set to:
|
|
||||||
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
|
|
||||||
|
|
||||||
|
# The stats section is reloadable. A HUP may change the backend, toggle stats
|
||||||
|
# on or off, switch the listen/host address, or pick up new DNS for the
|
||||||
|
# configured graphite host.
|
||||||
#stats:
|
#stats:
|
||||||
#type: graphite
|
#type: graphite
|
||||||
#prefix: nebula
|
#prefix: nebula
|
||||||
@@ -321,10 +324,12 @@ logging:
|
|||||||
# enables counter metrics for meta packets
|
# enables counter metrics for meta packets
|
||||||
# e.g.: `messages.tx.handshake`
|
# e.g.: `messages.tx.handshake`
|
||||||
# NOTE: `message.{tx,rx}.recv_error` is always emitted
|
# NOTE: `message.{tx,rx}.recv_error` is always emitted
|
||||||
|
# Not reloadable.
|
||||||
#message_metrics: false
|
#message_metrics: false
|
||||||
|
|
||||||
# enables detailed counter metrics for lighthouse packets
|
# enables detailed counter metrics for lighthouse packets
|
||||||
# e.g.: `lighthouse.rx.HostQuery`
|
# e.g.: `lighthouse.rx.HostQuery`
|
||||||
|
# Not reloadable.
|
||||||
#lighthouse_metrics: false
|
#lighthouse_metrics: false
|
||||||
|
|
||||||
# Handshake Manager Settings
|
# Handshake Manager Settings
|
||||||
@@ -382,8 +387,8 @@ firewall:
|
|||||||
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
|
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
|
||||||
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
|
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
|
||||||
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
|
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
|
||||||
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
|
|
||||||
# proto: `any`, `tcp`, `udp`, or `icmp`
|
# proto: `any`, `tcp`, `udp`, or `icmp`
|
||||||
|
# a port specification is ignored if proto is `icmp`
|
||||||
# host: `any` or a literal hostname, ie `test-host`
|
# host: `any` or a literal hostname, ie `test-host`
|
||||||
# group: `any` or a literal group name, ie `default-group`
|
# group: `any` or a literal group name, ie `default-group`
|
||||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/service"
|
"github.com/slackhq/nebula/service"
|
||||||
)
|
)
|
||||||
@@ -64,8 +64,7 @@ pki:
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := logrus.New()
|
logger := logging.NewLogger(os.Stdout)
|
||||||
logger.Out = os.Stdout
|
|
||||||
|
|
||||||
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
203
firewall.go
203
firewall.go
@@ -1,11 +1,13 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -16,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -67,7 +68,7 @@ type Firewall struct {
|
|||||||
incomingMetrics firewallMetrics
|
incomingMetrics firewallMetrics
|
||||||
outgoingMetrics firewallMetrics
|
outgoingMetrics firewallMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type firewallMetrics struct {
|
type firewallMetrics struct {
|
||||||
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
|
|||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
// The certificate provided should be the highest version loaded in memory.
|
// The certificate provided should be the highest version loaded in memory.
|
||||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||||
//TODO: error on 0 duration
|
//TODO: error on 0 duration
|
||||||
var tmin, tmax time.Duration
|
var tmin, tmax time.Duration
|
||||||
|
|
||||||
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||||
certificate := cs.getCertificate(cert.Version2)
|
certificate := cs.getCertificate(cert.Version2)
|
||||||
if certificate == nil {
|
if certificate == nil {
|
||||||
certificate = cs.getCertificate(cert.Version1)
|
certificate = cs.getCertificate(cert.Version1)
|
||||||
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
case "drop":
|
case "drop":
|
||||||
fw.InSendReject = false
|
fw.InSendReject = false
|
||||||
default:
|
default:
|
||||||
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`")
|
l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
|
||||||
fw.InSendReject = false
|
fw.InSendReject = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
case "drop":
|
case "drop":
|
||||||
fw.OutSendReject = false
|
fw.OutSendReject = false
|
||||||
default:
|
default:
|
||||||
l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`")
|
l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
|
||||||
fw.OutSendReject = false
|
fw.OutSendReject = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -249,20 +250,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
|
|
||||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
||||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
|
||||||
ruleString := fmt.Sprintf(
|
|
||||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
|
||||||
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
|
||||||
)
|
|
||||||
f.rules += ruleString + "\n"
|
|
||||||
|
|
||||||
direction := "incoming"
|
|
||||||
if !incoming {
|
|
||||||
direction = "outgoing"
|
|
||||||
}
|
|
||||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
|
||||||
Info("Firewall rule added")
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ft *FirewallTable
|
ft *FirewallTable
|
||||||
fp firewallPort
|
fp firewallPort
|
||||||
@@ -280,6 +267,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
case firewall.ProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
fp = ft.UDP
|
fp = ft.UDP
|
||||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||||
|
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
|
||||||
|
if startPort != firewall.PortAny {
|
||||||
|
f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
|
||||||
|
}
|
||||||
|
startPort = firewall.PortAny
|
||||||
|
endPort = firewall.PortAny
|
||||||
fp = ft.ICMP
|
fp = ft.ICMP
|
||||||
case firewall.ProtoAny:
|
case firewall.ProtoAny:
|
||||||
fp = ft.AnyProto
|
fp = ft.AnyProto
|
||||||
@@ -287,6 +280,21 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
return fmt.Errorf("unknown protocol %v", proto)
|
return fmt.Errorf("unknown protocol %v", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||||
|
ruleString := fmt.Sprintf(
|
||||||
|
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||||
|
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
||||||
|
)
|
||||||
|
f.rules += ruleString + "\n"
|
||||||
|
|
||||||
|
direction := "incoming"
|
||||||
|
if !incoming {
|
||||||
|
direction = "outgoing"
|
||||||
|
}
|
||||||
|
f.l.Info("Firewall rule added",
|
||||||
|
"firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
|
||||||
|
)
|
||||||
|
|
||||||
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
|
|||||||
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||||
var table string
|
var table string
|
||||||
if inbound {
|
if inbound {
|
||||||
table = "firewall.inbound"
|
table = "firewall.inbound"
|
||||||
@@ -349,24 +357,31 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
sPort = r.Port
|
sPort = r.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
startPort, endPort, err := parsePort(sPort)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var proto uint8
|
var proto uint8
|
||||||
|
var startPort, endPort int32
|
||||||
switch r.Proto {
|
switch r.Proto {
|
||||||
case "any":
|
case "any":
|
||||||
proto = firewall.ProtoAny
|
proto = firewall.ProtoAny
|
||||||
|
startPort, endPort, err = parsePort(sPort)
|
||||||
case "tcp":
|
case "tcp":
|
||||||
proto = firewall.ProtoTCP
|
proto = firewall.ProtoTCP
|
||||||
|
startPort, endPort, err = parsePort(sPort)
|
||||||
case "udp":
|
case "udp":
|
||||||
proto = firewall.ProtoUDP
|
proto = firewall.ProtoUDP
|
||||||
|
startPort, endPort, err = parsePort(sPort)
|
||||||
case "icmp":
|
case "icmp":
|
||||||
proto = firewall.ProtoICMP
|
proto = firewall.ProtoICMP
|
||||||
|
startPort = firewall.PortAny
|
||||||
|
endPort = firewall.PortAny
|
||||||
|
if sPort != "" {
|
||||||
|
l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
|
||||||
|
}
|
||||||
|
|
||||||
if r.Cidr != "" && r.Cidr != "any" {
|
if r.Cidr != "" && r.Cidr != "any" {
|
||||||
_, err = netip.ParsePrefix(r.Cidr)
|
_, err = netip.ParsePrefix(r.Cidr)
|
||||||
@@ -383,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
}
|
}
|
||||||
|
|
||||||
if warning := r.sanity(); warning != nil {
|
if warning := r.sanity(); warning != nil {
|
||||||
l.Warnf("%s rule #%v; %s", table, i, warning)
|
l.Warn("firewall rule sanity check",
|
||||||
|
"table", table,
|
||||||
|
"rule", i,
|
||||||
|
"warning", warning,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
||||||
@@ -467,7 +486,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
|
// Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new
|
||||||
// firewall object is created
|
// firewall object is created
|
||||||
func (f *Firewall) Destroy() {
|
func (f *Firewall) Destroy() {
|
||||||
//TODO: clean references if/when needed
|
//TODO: clean references if/when needed
|
||||||
@@ -515,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
|
|
||||||
// We now know which firewall table to check against
|
// We now know which firewall table to check against
|
||||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
h.logger(f.l).
|
h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
|
||||||
WithField("fwPacket", fp).
|
"fwPacket", fp,
|
||||||
WithField("incoming", c.incoming).
|
"incoming", c.incoming,
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
"rulesVersion", f.rulesVersion,
|
||||||
WithField("oldRulesVersion", c.rulesVersion).
|
"oldRulesVersion", c.rulesVersion,
|
||||||
Debugln("dropping old conntrack entry, does not match new ruleset")
|
)
|
||||||
}
|
}
|
||||||
delete(conntrack.Conns, fp)
|
delete(conntrack.Conns, fp)
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
h.logger(f.l).
|
h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
|
||||||
WithField("fwPacket", fp).
|
"fwPacket", fp,
|
||||||
WithField("incoming", c.incoming).
|
"incoming", c.incoming,
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
"rulesVersion", f.rulesVersion,
|
||||||
WithField("oldRulesVersion", c.rulesVersion).
|
"oldRulesVersion", c.rulesVersion,
|
||||||
Debugln("keeping old conntrack entry, does match new ruleset")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.rulesVersion = f.rulesVersion
|
c.rulesVersion = f.rulesVersion
|
||||||
@@ -660,6 +679,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match
|
||||||
|
if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 {
|
||||||
|
// port numbers are re-used for connection tracking of ICMP,
|
||||||
|
// but we don't want to actually filter on them.
|
||||||
|
return fp[firewall.PortAny].match(p, c, caPool)
|
||||||
|
}
|
||||||
|
|
||||||
var port int32
|
var port int32
|
||||||
|
|
||||||
if p.Fragment {
|
if p.Fragment {
|
||||||
@@ -804,10 +830,8 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, group := range groups {
|
if slices.Contains(groups, "any") {
|
||||||
if group == "any" {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if host == "any" {
|
if host == "any" {
|
||||||
@@ -917,7 +941,7 @@ type rule struct {
|
|||||||
CASha string
|
CASha string
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
|
||||||
r := rule{}
|
r := rule{}
|
||||||
|
|
||||||
m, ok := p.(map[string]any)
|
m, ok := p.(map[string]any)
|
||||||
@@ -948,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
l.Warn("group was an array with a single value, converting to simple value",
|
||||||
|
"table", table,
|
||||||
|
"rule", i,
|
||||||
|
)
|
||||||
m["group"] = v[0]
|
m["group"] = v[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1018,54 +1045,56 @@ func (r *rule) sanity() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.Code != "" {
|
||||||
|
return fmt.Errorf("code specified as [%s]. Support for 'code' will be dropped in a future release, as it has never been functional", r.Code)
|
||||||
|
}
|
||||||
|
|
||||||
//todo alert on cidr-any
|
//todo alert on cidr-any
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePort(s string) (startPort, endPort int32, err error) {
|
func parsePort(s string) (int32, int32, error) {
|
||||||
|
var err error
|
||||||
|
const notAPort int32 = -2
|
||||||
if s == "any" {
|
if s == "any" {
|
||||||
startPort = firewall.PortAny
|
return firewall.PortAny, firewall.PortAny, nil
|
||||||
endPort = firewall.PortAny
|
}
|
||||||
|
if s == "fragment" {
|
||||||
} else if s == "fragment" {
|
return firewall.PortFragment, firewall.PortFragment, nil
|
||||||
startPort = firewall.PortFragment
|
}
|
||||||
endPort = firewall.PortFragment
|
if !strings.Contains(s, `-`) {
|
||||||
|
|
||||||
} else if strings.Contains(s, `-`) {
|
|
||||||
sPorts := strings.SplitN(s, `-`, 2)
|
|
||||||
sPorts[0] = strings.Trim(sPorts[0], " ")
|
|
||||||
sPorts[1] = strings.Trim(sPorts[1], " ")
|
|
||||||
|
|
||||||
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
|
|
||||||
return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
rStartPort, err := strconv.Atoi(sPorts[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
rEndPort, err := strconv.Atoi(sPorts[1])
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
startPort = int32(rStartPort)
|
|
||||||
endPort = int32(rEndPort)
|
|
||||||
|
|
||||||
if startPort == firewall.PortAny {
|
|
||||||
endPort = firewall.PortAny
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
rPort, err := strconv.Atoi(s)
|
rPort, err := strconv.Atoi(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, fmt.Errorf("was not a number; `%s`", s)
|
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
|
||||||
}
|
}
|
||||||
startPort = int32(rPort)
|
return int32(rPort), int32(rPort), nil
|
||||||
endPort = startPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
sPorts := strings.SplitN(s, `-`, 2)
|
||||||
|
for i := range sPorts {
|
||||||
|
sPorts[i] = strings.Trim(sPorts[i], " ")
|
||||||
|
}
|
||||||
|
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
|
||||||
|
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
rStartPort, err := strconv.Atoi(sPorts[0])
|
||||||
|
if err != nil {
|
||||||
|
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
rEndPort, err := strconv.Atoi(sPorts[1])
|
||||||
|
if err != nil {
|
||||||
|
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
startPort := int32(rStartPort)
|
||||||
|
endPort := int32(rEndPort)
|
||||||
|
|
||||||
|
if startPort == firewall.PortAny {
|
||||||
|
endPort = firewall.PortAny
|
||||||
|
}
|
||||||
|
|
||||||
|
return startPort, endPort, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
@@ -15,41 +15,49 @@ type ConntrackCacheTicker struct {
|
|||||||
cacheV uint64
|
cacheV uint64
|
||||||
cacheTick atomic.Uint64
|
cacheTick atomic.Uint64
|
||||||
|
|
||||||
|
l *slog.Logger
|
||||||
cache ConntrackCache
|
cache ConntrackCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
|
||||||
if d == 0 {
|
if d == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ConntrackCacheTicker{
|
c := &ConntrackCacheTicker{
|
||||||
|
l: l,
|
||||||
cache: ConntrackCache{},
|
cache: ConntrackCache{},
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.tick(d)
|
go c.tick(ctx, d)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
|
||||||
|
t := time.NewTicker(d)
|
||||||
|
defer t.Stop()
|
||||||
for {
|
for {
|
||||||
time.Sleep(d)
|
select {
|
||||||
c.cacheTick.Add(1)
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
c.cacheTick.Add(1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
// the map. If it has moved, we reset the map.
|
// the map. If it has moved, we reset the map.
|
||||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||||
c.cacheV = tick
|
c.cacheV = tick
|
||||||
if ll := len(c.cache); ll > 0 {
|
if ll := len(c.cache); ll > 0 {
|
||||||
if l.Level == logrus.DebugLevel {
|
if c.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
c.l.Debug("resetting conntrack cache", "len", ll)
|
||||||
}
|
}
|
||||||
c.cache = make(ConntrackCache, ll)
|
c.cache = make(ConntrackCache, ll)
|
||||||
}
|
}
|
||||||
|
|||||||
69
firewall/cache_test.go
Normal file
69
firewall/cache_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The tests below pin the log format produced by ConntrackCacheTicker.Get
|
||||||
|
// so changes cannot silently break what operators are grepping for. The
|
||||||
|
// ticker's internal state (cache + cacheTick) is poked directly to avoid
|
||||||
|
// racing a goroutine-driven tick in tests.
|
||||||
|
|
||||||
|
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
|
||||||
|
t.Helper()
|
||||||
|
c := &ConntrackCacheTicker{
|
||||||
|
l: l,
|
||||||
|
cache: make(ConntrackCache, cacheLen),
|
||||||
|
}
|
||||||
|
for i := 0; i < cacheLen; i++ {
|
||||||
|
c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{}
|
||||||
|
}
|
||||||
|
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 3)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 2)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 5)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 0)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
}
|
||||||
@@ -22,7 +22,10 @@ const (
|
|||||||
type Packet struct {
|
type Packet struct {
|
||||||
LocalAddr netip.Addr
|
LocalAddr netip.Addr
|
||||||
RemoteAddr netip.Addr
|
RemoteAddr netip.Addr
|
||||||
LocalPort uint16
|
// LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP.
|
||||||
|
LocalPort uint16
|
||||||
|
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
|
||||||
|
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
|
||||||
RemotePort uint16
|
RemotePort uint16
|
||||||
Protocol uint8
|
Protocol uint8
|
||||||
Fragment bool
|
Fragment bool
|
||||||
@@ -46,6 +49,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
|
|||||||
proto = "tcp"
|
proto = "tcp"
|
||||||
case ProtoICMP:
|
case ProtoICMP:
|
||||||
proto = "icmp"
|
proto = "icmp"
|
||||||
|
case ProtoICMPv6:
|
||||||
|
proto = "icmpv6"
|
||||||
case ProtoUDP:
|
case ProtoUDP:
|
||||||
proto = "udp"
|
proto = "udp"
|
||||||
default:
|
default:
|
||||||
|
|||||||
260
firewall_test.go
260
firewall_test.go
@@ -3,13 +3,13 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
c := &dummyCert{}
|
c := &dummyCert{}
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
@@ -87,9 +86,10 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
//no matter what port is given for icmp, it should end up as "any"
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
|
||||||
|
assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
|
||||||
@@ -176,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
@@ -253,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropV6(t *testing.T) {
|
func TestFirewall_DropV6(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
@@ -484,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -543,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -632,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3V6(t *testing.T) {
|
func TestFirewall_Drop3V6(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
@@ -670,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -734,10 +728,152 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
|
network := netip.MustParsePrefix("1.2.3.4/24")
|
||||||
|
|
||||||
|
c := cert.CachedCertificate{
|
||||||
|
Certificate: &dummyCert{
|
||||||
|
name: "host1",
|
||||||
|
networks: []netip.Prefix{network},
|
||||||
|
groups: []string{"default-group"},
|
||||||
|
issuer: "signer-shasum",
|
||||||
|
},
|
||||||
|
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||||
|
}
|
||||||
|
h := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
|
}
|
||||||
|
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||||
|
|
||||||
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
|
templ := firewall.Packet{
|
||||||
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
|
Protocol: firewall.ProtoICMP,
|
||||||
|
Fragment: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ICMP allowed", func(t *testing.T) {
|
||||||
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||||
|
t.Run("zero ports", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0
|
||||||
|
p.RemotePort = 0
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero ports", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0xabcd
|
||||||
|
p.RemotePort = 0x1234
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Any proto, some ports allowed", func(t *testing.T) {
|
||||||
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
|
||||||
|
t.Run("zero ports, still blocked", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0
|
||||||
|
p.RemotePort = 0
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
//now also allow outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero ports, still blocked", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0xabcd
|
||||||
|
p.RemotePort = 0x1234
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
//now also allow outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 80
|
||||||
|
p.RemotePort = 80
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
//now also allow outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("Any proto, any port", func(t *testing.T) {
|
||||||
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||||
|
t.Run("zero ports, allowed", func(t *testing.T) {
|
||||||
|
resetConntrack(fw)
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0
|
||||||
|
p.RemotePort = 0
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
||||||
|
resetConntrack(fw)
|
||||||
|
p := templ.Copy()
|
||||||
|
p.LocalPort = 0xabcd
|
||||||
|
p.RemotePort = 0x1234
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||||
|
//now also allow outbound
|
||||||
|
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||||
|
//different ID is blocked
|
||||||
|
p.RemotePort++
|
||||||
|
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||||
|
ob := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||||
|
|
||||||
@@ -900,53 +1036,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||||
|
|
||||||
// Test both port and code
|
// Test both port and code
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||||
|
|
||||||
// Test missing host, group, cidr, ca_name and ca_sha
|
// Test missing host, group, cidr, ca_name and ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||||
|
|
||||||
// Test code/port error
|
// Test code/port error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||||
|
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||||
|
|
||||||
// Test proto error
|
// Test proto error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||||
|
|
||||||
// Test cidr parse error
|
// Test cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test local_cidr parse error
|
// Test local_cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test both group and groups
|
// Test both group and groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||||
@@ -955,28 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// Test adding tcp rule
|
// Test adding tcp rule
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
|
// Test adding icmp rule no port
|
||||||
|
conf = config.NewC(test.NewLogger())
|
||||||
|
mf = &mockFirewall{}
|
||||||
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
|
||||||
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
@@ -984,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding rule with cidr
|
// Test adding rule with cidr
|
||||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with local_cidr
|
// Test adding rule with local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
@@ -999,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding rule with cidr ipv6
|
// Test adding rule with cidr ipv6
|
||||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any cidr
|
// Test adding rule with any cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with junk cidr
|
// Test adding rule with junk cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||||
|
|
||||||
// Test adding rule with local_cidr ipv6
|
// Test adding rule with local_cidr ipv6
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any local_cidr
|
// Test adding rule with any local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with junk local_cidr
|
// Test adding rule with junk local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
mf.nextCallReturn = errors.New("test error")
|
mf.nextCallReturn = errors.New("test error")
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
@@ -1082,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRule(t *testing.T) {
|
func TestFirewall_convertRule(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
// Ensure group array of 1 is converted and a warning is printed
|
// Ensure group array of 1 is converted and a warning is printed
|
||||||
c := map[string]any{
|
c := map[string]any{
|
||||||
@@ -1092,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
r, err := convertRule(l, c, "test", 1)
|
||||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
|
||||||
|
assert.Contains(t, ob.String(), "table=test")
|
||||||
|
assert.Contains(t, ob.String(), "rule=1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
assert.Equal(t, []string{"group1"}, r.Groups)
|
||||||
|
|
||||||
@@ -1118,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRuleSanity(t *testing.T) {
|
func TestFirewall_convertRuleSanity(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
noWarningPlease := []map[string]any{
|
noWarningPlease := []map[string]any{
|
||||||
{"group": "group1"},
|
{"group": "group1"},
|
||||||
@@ -1234,7 +1377,7 @@ type testsetup struct {
|
|||||||
fw *Firewall
|
fw *Firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
|
||||||
c := dummyCert{
|
c := dummyCert{
|
||||||
name: "me",
|
name: "me",
|
||||||
networks: myPrefixes,
|
networks: myPrefixes,
|
||||||
@@ -1245,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
|
|||||||
return newSetupFromCert(t, l, c)
|
return newSetupFromCert(t, l, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
for _, prefix := range c.Networks() {
|
for _, prefix := range c.Networks() {
|
||||||
myVpnNetworksTable.Insert(prefix)
|
myVpnNetworksTable.Insert(prefix)
|
||||||
@@ -1262,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|||||||
|
|
||||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
||||||
|
|||||||
24
go.mod
24
go.mod
@@ -1,9 +1,10 @@
|
|||||||
module github.com/slackhq/nebula
|
module github.com/slackhq/nebula
|
||||||
|
|
||||||
go 1.25
|
go 1.25.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
dario.cat/mergo v1.0.2
|
dario.cat/mergo v1.0.2
|
||||||
|
filippo.io/bigmod v0.1.0
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
||||||
github.com/armon/go-radix v1.0.0
|
github.com/armon/go-radix v1.0.0
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||||
@@ -12,26 +13,25 @@ require (
|
|||||||
github.com/gogo/protobuf v1.3.2
|
github.com/gogo/protobuf v1.3.2
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/kardianos/service v1.2.4
|
github.com/kardianos/service v1.2.4
|
||||||
github.com/miekg/dns v1.1.70
|
github.com/miekg/dns v1.1.72
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
github.com/miekg/pkcs11 v1.1.2
|
||||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||||
github.com/sirupsen/logrus v1.9.4
|
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
go.yaml.in/yaml/v3 v3.0.4
|
go.yaml.in/yaml/v3 v3.0.4
|
||||||
golang.org/x/crypto v0.47.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.49.0
|
golang.org/x/net v0.52.0
|
||||||
golang.org/x/sync v0.19.0
|
golang.org/x/sync v0.20.0
|
||||||
golang.org/x/sys v0.40.0
|
golang.org/x/sys v0.43.0
|
||||||
golang.org/x/term v0.39.0
|
golang.org/x/term v0.42.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.6.1
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||||
@@ -49,7 +49,7 @@ require (
|
|||||||
github.com/prometheus/procfs v0.16.1 // indirect
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
golang.org/x/mod v0.31.0 // indirect
|
golang.org/x/mod v0.34.0 // indirect
|
||||||
golang.org/x/time v0.5.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.40.0 // indirect
|
golang.org/x/tools v0.43.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
44
go.sum
44
go.sum
@@ -1,6 +1,8 @@
|
|||||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||||
|
filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8=
|
||||||
|
filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||||
@@ -83,10 +85,10 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
|
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
|
||||||
github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
|
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
@@ -131,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
|
|||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
|
||||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
||||||
@@ -162,16 +162,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -191,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
@@ -223,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
|||||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
@@ -233,8 +233,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
|
|||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.6.1 h1:XMaKojH1Hs/raMrmnir4n35nTvzvWj7NmSYzHn2F4qU=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.6.1/go.mod h1:04aqInu5GYuTFvMuDw/rKBAF7mHrltW/3rekpfbbZDM=
|
||||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||||
|
|||||||
551
handshake_ix.go
551
handshake_ix.go
@@ -2,11 +2,12 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
@@ -18,8 +19,11 @@ import (
|
|||||||
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||||
err := f.handshakeManager.allocateIndex(hh)
|
err := f.handshakeManager.allocateIndex(hh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to generate index",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
"error", err,
|
||||||
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
crt := cs.getCertificate(v)
|
crt := cs.getCertificate(v)
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
WithField("certVersion", v).
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
"certVersion", v,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
crtHs := cs.getHandshakeBytes(v)
|
crtHs := cs.getHandshakeBytes(v)
|
||||||
if crtHs == nil {
|
if crtHs == nil {
|
||||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
WithField("certVersion", v).
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
"certVersion", v,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to create connection state",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"error", err,
|
||||||
WithField("certVersion", v).
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
Error("Failed to create connection state")
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
"certVersion", v,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
hh.hostinfo.ConnectionState = ci
|
hh.hostinfo.ConnectionState = ci
|
||||||
@@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to marshal handshake message",
|
||||||
WithField("certVersion", v).
|
"error", err,
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
|
"certVersion", v,
|
||||||
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to call noise.WriteMessage",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
"error", err,
|
||||||
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
crt := cs.GetDefaultCertificate()
|
crt := cs.GetDefaultCertificate()
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
f.l.WithField("from", via).
|
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"from", via,
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
"certVersion", cs.initiatingVersion,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Error("Failed to create connection state",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Error("Failed to create connection state")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Error("Failed to call noise.ReadMessage",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Error("Failed to call noise.ReadMessage")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Error("Failed unmarshal handshake message",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Error("Failed unmarshal handshake message")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("Handshake did not contain a certificate",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Info("Handshake did not contain a certificate")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
attrs := []slog.Attr{
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
slog.Any("error", err),
|
||||||
WithField("certVpnNetworks", rc.Networks()).
|
slog.Any("from", via),
|
||||||
WithField("certFingerprint", fp)
|
slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
|
||||||
|
slog.Any("certVpnNetworks", rc.Networks()),
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
slog.String("certFingerprint", fp),
|
||||||
e = e.WithField("cert", rc)
|
}
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
attrs = append(attrs, slog.Any("cert", rc))
|
||||||
}
|
}
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||||
|
// callers grow conditionally, which has no pair-form equivalent.
|
||||||
|
//nolint:sloglint
|
||||||
|
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||||
f.l.WithField("from", via).
|
f.l.Info("public key mismatch between certificate and handshake",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"from", via,
|
||||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"cert", remoteCert,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if myCertOtherVersion == nil {
|
if myCertOtherVersion == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithError(err).WithFields(m{
|
f.l.Debug("Might be unable to handshake with host due to missing certificate version",
|
||||||
"from": via,
|
"error", err,
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
"from", via,
|
||||||
"cert": remoteCert,
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
"cert", remoteCert,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Record the certificate we are actually using
|
// Record the certificate we are actually using
|
||||||
@@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("No networks in certificate",
|
||||||
WithField("cert", remoteCert).
|
"error", err,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"from", via,
|
||||||
Info("No networks in certificate")
|
"cert", remoteCert,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
for i, network := range vpnNetworks {
|
for i, network := range vpnNetworks {
|
||||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
|
f.l.Error("Refusing to handshake with myself",
|
||||||
WithField("certName", certName).
|
"vpnNetworks", vpnNetworks,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrs[i] = network.Addr()
|
vpnAddrs[i] = network.Addr()
|
||||||
@@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
// We only want to apply the remote allow list for direct tunnels here
|
// We only want to apply the remote allow list for direct tunnels here
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||||
|
"vpnAddrs", vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
myIndex, err := generateIndex(f.l)
|
myIndex, err := generateIndex(f.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to generate index",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
"certVersion", certVersion,
|
||||||
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgRxL := f.l.WithFields(m{
|
msgRxL := f.l.With(
|
||||||
"vpnAddrs": vpnAddrs,
|
"vpnAddrs", vpnAddrs,
|
||||||
"from": via,
|
"from", via,
|
||||||
"certName": certName,
|
"certName", certName,
|
||||||
"certVersion": certVersion,
|
"certVersion", certVersion,
|
||||||
"fingerprint": fingerprint,
|
"fingerprint", fingerprint,
|
||||||
"issuer": issuer,
|
"issuer", issuer,
|
||||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
"responderIndex": hs.Details.ResponderIndex,
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
"remoteIndex": h.RemoteIndex,
|
"remoteIndex", h.RemoteIndex,
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
})
|
)
|
||||||
|
|
||||||
if anyVpnAddrsInCommon {
|
if anyVpnAddrsInCommon {
|
||||||
msgRxL.Info("Handshake message received")
|
msgRxL.Info("Handshake message received")
|
||||||
@@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
"myCertVersion", ci.myCert.Version(),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to marshal handshake message",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
"certVersion", certVersion,
|
||||||
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to call noise.WriteMessage",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
"certVersion", certVersion,
|
||||||
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Noise did not arrive at a key",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
err := f.outside.WriteTo(msg, via.UdpAddr)
|
err := f.outside.WriteTo(msg, via.UdpAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to send handshake message",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
"vpnAddrs", existing.vpnAddrs,
|
||||||
WithError(err).Error("Failed to send handshake message")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cached", true,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
f.l.Info("Handshake message sent",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
"vpnAddrs", existing.vpnAddrs,
|
||||||
Info("Handshake message sent")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cached", true,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
@@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
}
|
}
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
f.l.Info("Handshake message sent",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
"vpnAddrs", existing.vpnAddrs,
|
||||||
Info("Handshake message sent")
|
"relay", via.relayHI.vpnAddrs[0],
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cached", true,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case ErrExistingHostInfo:
|
case ErrExistingHostInfo:
|
||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Info("Handshake too old",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
"certName", certName,
|
||||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
"certVersion", certVersion,
|
||||||
WithField("fingerprint", fingerprint).
|
"oldHandshakeTime", existing.lastHandshakeTime,
|
||||||
WithField("issuer", issuer).
|
"newHandshakeTime", hostinfo.lastHandshakeTime,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
Info("Handshake too old")
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
return
|
return
|
||||||
case ErrLocalIndexCollision:
|
case ErrLocalIndexCollision:
|
||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to add HostInfo due to localIndex collision",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
Error("Failed to add HostInfo due to localIndex collision")
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"localIndex", hostinfo.localIndexId,
|
||||||
|
"collision", existing.vpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to add HostInfo to HostMap",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"certVersion", certVersion,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"fingerprint", fingerprint,
|
||||||
Error("Failed to add HostInfo to HostMap")
|
"issuer", issuer,
|
||||||
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
err = f.outside.WriteTo(msg, via.UdpAddr)
|
err = f.outside.WriteTo(msg, via.UdpAddr)
|
||||||
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
log := f.l.With(
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
"issuer", issuer,
|
||||||
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Failed to send handshake")
|
log.Error("Failed to send handshake", "error", err)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Handshake message sent")
|
log.Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
@@ -448,20 +533,29 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
// it's correctly marked as working.
|
// it's correctly marked as working.
|
||||||
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
f.l.Info("Handshake message sent",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"relay", via.relayHI.vpnAddrs[0],
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
Info("Handshake message sent")
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||||
|
|
||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||||
|
|
||||||
|
// Don't wait for UpdateWorker
|
||||||
|
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
|
||||||
|
f.lightHouse.TriggerUpdate()
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -478,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
)
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -486,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to call noise.ReadMessage",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
"error", err,
|
||||||
Error("Failed to call noise.ReadMessage")
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"header", h,
|
||||||
|
)
|
||||||
|
|
||||||
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
|
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
|
||||||
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
||||||
// near future
|
// near future
|
||||||
return false
|
return false
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Noise did not arrive at a key",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Error("Noise did not arrive at a key")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
|
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
|
||||||
// the handshake state machine. Tear it down
|
// the handshake state machine. Tear it down
|
||||||
@@ -507,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed unmarshal handshake message",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
"error", err,
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||||
return true
|
return true
|
||||||
@@ -516,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("Handshake did not contain a certificate",
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
"error", err,
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"from", via,
|
||||||
Info("Handshake did not contain a certificate")
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
attrs := []slog.Attr{
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
slog.Any("error", err),
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
slog.Any("from", via),
|
||||||
WithField("certFingerprint", fp).
|
slog.Any("vpnAddrs", hostinfo.vpnAddrs),
|
||||||
WithField("certVpnNetworks", rc.Networks())
|
slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
|
||||||
|
slog.String("certFingerprint", fp),
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
slog.Any("certVpnNetworks", rc.Networks()),
|
||||||
e = e.WithField("cert", rc)
|
}
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
attrs = append(attrs, slog.Any("cert", rc))
|
||||||
}
|
}
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||||
|
// callers grow conditionally, which has no pair-form equivalent.
|
||||||
|
//nolint:sloglint
|
||||||
|
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||||
f.l.WithField("from", via).
|
f.l.Info("public key mismatch between certificate and handshake",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"from", via,
|
||||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cert", remoteCert,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("No networks in certificate",
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
"error", err,
|
||||||
WithField("cert", remoteCert).
|
"from", via,
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("No networks in certificate")
|
"cert", remoteCert,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -596,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !correctHostResponded {
|
if !correctHostResponded {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.Info("Incorrect host responded to handshake",
|
||||||
WithField("from", via).
|
"intendedVpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("certName", certName).
|
"haveVpnNetworks", vpnNetworks,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"certName", certName,
|
||||||
Info("Incorrect host responded to handshake")
|
"certVersion", certVersion,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// Release our old handshake from pending, it should not continue
|
// Release our old handshake from pending, it should not continue
|
||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
@@ -613,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
newHH.hostinfo.remotes.BlockRemote(via)
|
newHH.hostinfo.remotes.BlockRemote(via)
|
||||||
|
|
||||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
f.l.Info("Blocked addresses for handshakes",
|
||||||
WithField("vpnNetworks", vpnNetworks).
|
"blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
|
||||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
"vpnNetworks", vpnNetworks,
|
||||||
Info("Blocked addresses for handshakes")
|
"remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
|
||||||
|
)
|
||||||
|
|
||||||
// Swap the packet store to benefit the original intended recipient
|
// Swap the packet store to benefit the original intended recipient
|
||||||
newHH.packetStore = hh.packetStore
|
newHH.packetStore = hh.packetStore
|
||||||
@@ -634,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
msgRxL := f.l.With(
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
WithField("durationNs", duration).
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
WithField("sentCachedPackets", len(hh.packetStore))
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"durationNs", duration,
|
||||||
|
"sentCachedPackets", len(hh.packetStore),
|
||||||
|
)
|
||||||
if anyVpnAddrsInCommon {
|
if anyVpnAddrsInCommon {
|
||||||
msgRxL.Info("Handshake message received")
|
msgRxL.Info("Handshake message received")
|
||||||
} else {
|
} else {
|
||||||
@@ -658,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
hostinfo.logger(f.l).Debug("Sending stored packets",
|
||||||
|
"count", len(hh.packetStore),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(hh.packetStore) > 0 {
|
if len(hh.packetStore) > 0 {
|
||||||
@@ -674,5 +804,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||||
f.metricHandshakes.Update(duration)
|
f.metricHandshakes.Update(duration)
|
||||||
|
|
||||||
|
// Don't wait for UpdateWorker
|
||||||
|
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
|
||||||
|
f.lightHouse.TriggerUpdate()
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
@@ -59,7 +59,7 @@ type HandshakeManager struct {
|
|||||||
metricInitiated metrics.Counter
|
metricInitiated metrics.Counter
|
||||||
metricTimedOut metrics.Counter
|
metricTimedOut metrics.Counter
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
// can be used to trigger outbound handshake for the given vpnIp
|
// can be used to trigger outbound handshake for the given vpnIp
|
||||||
trigger chan netip.Addr
|
trigger chan netip.Addr
|
||||||
@@ -78,32 +78,32 @@ type HandshakeHostInfo struct {
|
|||||||
hostinfo *HostInfo
|
hostinfo *HostInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||||
if len(hh.packetStore) < 100 {
|
if len(hh.packetStore) < 100 {
|
||||||
tempPacket := make([]byte, len(packet))
|
tempPacket := make([]byte, len(packet))
|
||||||
copy(tempPacket, packet)
|
copy(tempPacket, packet)
|
||||||
|
|
||||||
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
|
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hh.hostinfo.logger(l).
|
hh.hostinfo.logger(l).Debug("Packet store",
|
||||||
WithField("length", len(hh.packetStore)).
|
"length", len(hh.packetStore),
|
||||||
WithField("stored", true).
|
"stored", true,
|
||||||
Debugf("Packet store")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
m.dropped.Inc(1)
|
m.dropped.Inc(1)
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hh.hostinfo.logger(l).
|
hh.hostinfo.logger(l).Debug("Packet store",
|
||||||
WithField("length", len(hh.packetStore)).
|
"length", len(hh.packetStore),
|
||||||
WithField("stored", false).
|
"stored", false,
|
||||||
Debugf("Packet store")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
|
func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
|
||||||
return &HandshakeManager{
|
return &HandshakeManager{
|
||||||
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
|
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
|
||||||
indexes: map[uint32]*HandshakeHostInfo{},
|
indexes: map[uint32]*HandshakeHostInfo{},
|
||||||
@@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
|
|||||||
// First remote allow list check before we know the vpnIp
|
// First remote allow list check before we know the vpnIp
|
||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
|
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
|
||||||
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo := hh.hostinfo
|
hostinfo := hh.hostinfo
|
||||||
// If we are out of time, clean up
|
// If we are out of time, clean up
|
||||||
if hh.counter >= hm.config.retries {
|
if hh.counter >= hm.config.retries {
|
||||||
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
|
hh.hostinfo.logger(hm.l).Info("Handshake timed out",
|
||||||
WithField("initiatorIndex", hh.hostinfo.localIndexId).
|
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
|
||||||
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
|
"initiatorIndex", hh.hostinfo.localIndexId,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"remoteIndex", hh.hostinfo.remoteIndexId,
|
||||||
WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
Info("Handshake timed out")
|
"durationNs", time.Since(hh.startTime).Nanoseconds(),
|
||||||
|
)
|
||||||
hm.metricTimedOut.Inc(1)
|
hm.metricTimedOut.Inc(1)
|
||||||
hm.DeleteHostInfo(hostinfo)
|
hm.DeleteHostInfo(hostinfo)
|
||||||
return
|
return
|
||||||
@@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).WithField("udpAddr", addr).
|
hostinfo.logger(hm.l).Error("Failed to send handshake message",
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
"udpAddr", addr,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"initiatorIndex", hostinfo.localIndexId,
|
||||||
WithError(err).Error("Failed to send handshake message")
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
sentTo = append(sentTo, addr)
|
sentTo = append(sentTo, addr)
|
||||||
@@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
|
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
|
||||||
// so only log when the list of remotes has changed
|
// so only log when the list of remotes has changed
|
||||||
if remotesHaveChanged {
|
if remotesHaveChanged {
|
||||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
hostinfo.logger(hm.l).Info("Handshake message sent",
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
"udpAddrs", sentTo,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"initiatorIndex", hostinfo.localIndexId,
|
||||||
Info("Handshake message sent")
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
} else if hm.l.Level >= logrus.DebugLevel {
|
)
|
||||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
} else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
hostinfo.logger(hm.l).Debug("Handshake message sent",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"udpAddrs", sentTo,
|
||||||
Debug("Handshake message sent")
|
"initiatorIndex", hostinfo.localIndexId,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
|
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
|
||||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
|
||||||
// Send a RelayRequest to all known Relay IP's
|
// Send a RelayRequest to all known Relay IP's
|
||||||
for _, relay := range hostinfo.remotes.relays {
|
for _, relay := range hostinfo.remotes.relays {
|
||||||
// Don't relay through the host I'm trying to connect to
|
// Don't relay through the host I'm trying to connect to
|
||||||
@@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
|
|
||||||
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
|
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
|
||||||
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
|
hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
|
||||||
hm.f.Handshake(relay)
|
hm.f.Handshake(relay)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
if relayHostInfo.remote.IsValid() {
|
if relayHostInfo.remote.IsValid() {
|
||||||
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := NebulaControl{
|
m := NebulaControl{
|
||||||
@@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
|
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||||
WithError(err).
|
|
||||||
Error("Failed to marshal Control message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": hm.f.myVpnAddrs[0],
|
"relayFrom", hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo", vpnIp,
|
||||||
"initiatorRelayIndex": idx,
|
"initiatorRelayIndex", idx,
|
||||||
"relay": relay}).
|
"relay", relay,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
|
|
||||||
switch existingRelay.State {
|
switch existingRelay.State {
|
||||||
case Established:
|
case Established:
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
|
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
|
||||||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
||||||
case Disestablished:
|
case Disestablished:
|
||||||
// Mark this relay as 'requested'
|
// Mark this relay as 'requested'
|
||||||
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
||||||
fallthrough
|
fallthrough
|
||||||
case Requested:
|
case Requested:
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
|
||||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
// Re-send the CreateRelay request, in case the previous one was lost.
|
||||||
m := NebulaControl{
|
m := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
@@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
}
|
}
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||||
WithError(err).
|
|
||||||
Error("Failed to marshal Control message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": hm.f.myVpnAddrs[0],
|
"relayFrom", hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo", vpnIp,
|
||||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
"initiatorRelayIndex", existingRelay.LocalIndex,
|
||||||
"relay": relay}).
|
"relay", relay,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
case PeerRequested:
|
case PeerRequested:
|
||||||
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
|
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Error("Relay unexpected state",
|
||||||
WithField("vpnIp", vpnIp).
|
"vpnIp", vpnIp,
|
||||||
WithField("state", existingRelay.State).
|
"state", existingRelay.State,
|
||||||
WithField("relay", relay).
|
"relay", relay,
|
||||||
Errorf("Relay unexpected state")
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
|||||||
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
|
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
"remoteIndex", hostinfo.remoteIndexId,
|
||||||
Info("New host shadows existing host remoteIndex")
|
"collision", existingRemoteIndex.vpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
||||||
@@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
|||||||
if found && existingRemoteIndex != nil {
|
if found && existingRemoteIndex != nil {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
"remoteIndex", hostinfo.remoteIndexId,
|
||||||
Info("New host shadows existing host remoteIndex")
|
"collision", existingRemoteIndex.vpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
|
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
|
||||||
@@ -590,7 +593,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
|
|||||||
hm.Lock()
|
hm.Lock()
|
||||||
defer hm.Unlock()
|
defer hm.Unlock()
|
||||||
|
|
||||||
for i := 0; i < 32; i++ {
|
for range 32 {
|
||||||
index, err := generateIndex(hm.l)
|
index, err := generateIndex(hm.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
|||||||
hm.indexes = map[uint32]*HandshakeHostInfo{}
|
hm.indexes = map[uint32]*HandshakeHostInfo{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
hm.l.Debug("Pending hostmap hostInfo deleted",
|
||||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
||||||
Debug("Pending hostmap hostInfo deleted")
|
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() {
|
|||||||
|
|
||||||
// Utility functions below
|
// Utility functions below
|
||||||
|
|
||||||
func generateIndex(l *logrus.Logger) (uint32, error) {
|
func generateIndex(l *slog.Logger) (uint32, error) {
|
||||||
b := make([]byte, 4)
|
b := make([]byte, 4)
|
||||||
|
|
||||||
// Let zero mean we don't know the ID, so don't generate zero
|
// Let zero mean we don't know the ID, so don't generate zero
|
||||||
@@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
|
|||||||
for index == 0 {
|
for index == 0 {
|
||||||
_, err := rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorln(err)
|
l.Error("Failed to generate index", "error", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
index = binary.BigEndian.Uint32(b)
|
index = binary.BigEndian.Uint32(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("index", index).
|
l.Debug("Generated index", "index", index)
|
||||||
Debug("Generated index")
|
|
||||||
}
|
}
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|||||||
80
hostmap.go
80
hostmap.go
@@ -1,9 +1,11 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -13,10 +15,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||||
@@ -60,7 +62,7 @@ type HostMap struct {
|
|||||||
RemoteIndexes map[uint32]*HostInfo
|
RemoteIndexes map[uint32]*HostInfo
|
||||||
Hosts map[netip.Addr]*HostInfo
|
Hosts map[netip.Addr]*HostInfo
|
||||||
preferredRanges atomic.Pointer[[]netip.Prefix]
|
preferredRanges atomic.Pointer[[]netip.Prefix]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
|
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
|
||||||
@@ -313,7 +315,7 @@ type cachedPacketMetrics struct {
|
|||||||
dropped metrics.Counter
|
dropped metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
|
||||||
hm := newHostMap(l)
|
hm := newHostMap(l)
|
||||||
|
|
||||||
hm.reload(c, true)
|
hm.reload(c, true)
|
||||||
@@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
|||||||
hm.reload(c, false)
|
hm.reload(c, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("preferredRanges", hm.GetPreferredRanges()).
|
l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
|
||||||
Info("Main HostMap created")
|
|
||||||
|
|
||||||
return hm
|
return hm
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostMap(l *logrus.Logger) *HostMap {
|
func newHostMap(l *slog.Logger) *HostMap {
|
||||||
return &HostMap{
|
return &HostMap{
|
||||||
Indexes: map[uint32]*HostInfo{},
|
Indexes: map[uint32]*HostInfo{},
|
||||||
Relays: map[uint32]*HostInfo{},
|
Relays: map[uint32]*HostInfo{},
|
||||||
@@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
|||||||
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
|
hm.l.Warn("Failed to parse preferred ranges, ignoring",
|
||||||
|
"error", err,
|
||||||
|
"range", rawPreferredRanges,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
|||||||
|
|
||||||
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
||||||
if !initial {
|
if !initial {
|
||||||
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
|
hm.l.Info("preferred_ranges changed",
|
||||||
|
"oldPreferredRanges", *oldRanges,
|
||||||
|
"newPreferredRanges", preferredRanges,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
|
|||||||
hm.Indexes = map[uint32]*HostInfo{}
|
hm.Indexes = map[uint32]*HostInfo{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
|
hm.l.Debug("Hostmap hostInfo deleted",
|
||||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"hostMap", m{"mapTotalSize": len(hm.Hosts),
|
||||||
Debug("Hostmap hostInfo deleted")
|
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLastHostinfo {
|
if isLastHostinfo {
|
||||||
@@ -604,9 +612,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
|
|||||||
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
|
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
|
||||||
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
|
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
|
||||||
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||||
if f.serveDns {
|
if f.dnsServer != nil {
|
||||||
remoteCert := hostinfo.ConnectionState.peerCert
|
remoteCert := hostinfo.ConnectionState.peerCert
|
||||||
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
for _, addr := range hostinfo.vpnAddrs {
|
for _, addr := range hostinfo.vpnAddrs {
|
||||||
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
||||||
@@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
|||||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
hm.l.Debug("Hostmap vpnIp added",
|
||||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
|
"hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
||||||
Debug("Hostmap vpnIp added")
|
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
// logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
|
||||||
|
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
|
||||||
if i == nil {
|
if i == nil {
|
||||||
return logrus.NewEntry(l)
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
li := l.WithField("vpnAddrs", i.vpnAddrs).
|
li := l.With(
|
||||||
WithField("localIndex", i.localIndexId).
|
"vpnAddrs", i.vpnAddrs,
|
||||||
WithField("remoteIndex", i.remoteIndexId)
|
"localIndex", i.localIndexId,
|
||||||
|
"remoteIndex", i.remoteIndexId,
|
||||||
|
)
|
||||||
|
|
||||||
if connState := i.ConnectionState; connState != nil {
|
if connState := i.ConnectionState; connState != nil {
|
||||||
if peerCert := connState.peerCert; peerCert != nil {
|
if peerCert := connState.peerCert; peerCert != nil {
|
||||||
li = li.WithField("certName", peerCert.Certificate.Name())
|
li = li.With("certName", peerCert.Certificate.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
|||||||
|
|
||||||
// Utility functions
|
// Utility functions
|
||||||
|
|
||||||
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||||
//FIXME: This function is pretty garbage
|
//FIXME: This function is pretty garbage
|
||||||
var finalAddrs []netip.Addr
|
var finalAddrs []netip.Addr
|
||||||
ifaces, _ := net.Interfaces()
|
ifaces, _ := net.Interfaces()
|
||||||
for _, i := range ifaces {
|
for _, i := range ifaces {
|
||||||
allow := allowList.AllowName(i.Name)
|
allow := allowList.AllowName(i.Name)
|
||||||
if l.Level >= logrus.TraceLevel {
|
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
|
l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
|
||||||
|
"interfaceName", i.Name,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !addr.IsValid() {
|
if !addr.IsValid() {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
|
l.Debug("addr was invalid", "localAddr", rawAddr)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
|||||||
|
|
||||||
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
||||||
isAllowed := allowList.Allow(addr)
|
isAllowed := allowList.Allow(addr)
|
||||||
if l.Level >= logrus.TraceLevel {
|
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
|
l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
|
||||||
|
"localAddr", addr,
|
||||||
|
"allowed", isAllowed,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !isAllowed {
|
if !isAllowed {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
|
|
||||||
func TestHostMap_reload(t *testing.T) {
|
func TestHostMap_reload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(test.NewLogger())
|
||||||
|
|
||||||
hm := NewHostMapFromConfig(l, c)
|
hm := NewHostMapFromConfig(l, c)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build e2e_testing
|
//go:build e2e_testing
|
||||||
// +build e2e_testing
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
119
inside.go
119
inside.go
@@ -1,9 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
@@ -14,8 +15,11 @@ import (
|
|||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
f.l.Debug("Error while validating outbound packet",
|
||||||
|
"packet", packet,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
if immediatelyForwardToSelf {
|
if immediatelyForwardToSelf {
|
||||||
_, err := f.readers[q].Write(packet)
|
_, err := f.readers[q].Write(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to forward to tun")
|
f.l.Error("Failed to forward to tun", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||||
@@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
|
||||||
WithField("fwPacket", fwPacket).
|
"vpnAddr", fwPacket.RemoteAddr,
|
||||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
"fwPacket", fwPacket,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).
|
hostinfo.logger(f.l).Debug("dropping outbound packet",
|
||||||
WithField("fwPacket", fwPacket).
|
"fwPacket", fwPacket,
|
||||||
WithField("reason", dropReason).
|
"reason", dropReason,
|
||||||
Debugln("dropping outbound packet")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
|||||||
|
|
||||||
_, err := f.readers[q].Write(out)
|
_, err := f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.Error("Failed to write to tun", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(out) > iputil.MaxRejectPacketSize {
|
if len(out) > iputil.MaxRejectPacketSize {
|
||||||
if f.l.GetLevel() >= logrus.InfoLevel {
|
if f.l.Enabled(context.Background(), slog.LevelInfo) {
|
||||||
f.l.
|
f.l.Info("rejectOutside: packet too big, not sending",
|
||||||
WithField("packet", packet).
|
"packet", packet,
|
||||||
WithField("outPacket", out).
|
"outPacket", out,
|
||||||
Info("rejectOutside: packet too big, not sending")
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
|
|||||||
// This would also need to interact with unsafe_route updates through reloading the config or
|
// This would also need to interact with unsafe_route updates through reloading the config or
|
||||||
// use of the use_system_route_table option
|
// use of the use_system_route_table option
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("destination", destinationAddr).
|
f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
|
||||||
WithField("originalGateway", gatewayAddr).
|
"destination", destinationAddr,
|
||||||
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
|
"originalGateway", gatewayAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range gateways {
|
for i := range gateways {
|
||||||
@@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
fp := &firewall.Packet{}
|
fp := &firewall.Packet{}
|
||||||
err := newPacket(p, false, fp)
|
err := newPacket(p, false, fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("fwPacket", fp).
|
f.l.Debug("dropping cached packet",
|
||||||
WithField("reason", dropReason).
|
"fwPacket", fp,
|
||||||
Debugln("dropping cached packet")
|
"reason", dropReason,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
|
|||||||
})
|
})
|
||||||
|
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).
|
f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
|
||||||
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
|
"vpnAddr", vpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
if noiseutil.EncryptLockNeeded {
|
if noiseutil.EncryptLockNeeded {
|
||||||
via.ConnectionState.writeLock.Unlock()
|
via.ConnectionState.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
via.logger(f.l).
|
via.logger(f.l).Error("SendVia out buffer not large enough for relay",
|
||||||
WithField("outCap", cap(out)).
|
"outCap", cap(out),
|
||||||
WithField("payloadLen", len(ad)).
|
"payloadLen", len(ad),
|
||||||
WithField("headerLen", len(out)).
|
"headerLen", len(out),
|
||||||
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()).
|
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
|
||||||
Error("SendVia out buffer not large enough for relay")
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
via.ConnectionState.writeLock.Unlock()
|
via.ConnectionState.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
|
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = f.writers[0].WriteTo(out, via.remote)
|
err = f.writers[0].WriteTo(out, via.remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
|
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
|
||||||
}
|
}
|
||||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||||
}
|
}
|
||||||
@@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
ci.writeLock.Unlock()
|
ci.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
|
||||||
WithField("udpAddr", remote).WithField("counter", c).
|
"error", err,
|
||||||
WithField("attemptedCounter", c).
|
"udpAddr", remote,
|
||||||
Error("Failed to encrypt outgoing packet")
|
"counter", c,
|
||||||
|
"attemptedCounter", c,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if remote.IsValid() {
|
if remote.IsValid() {
|
||||||
err = f.writers[q].WriteTo(out, remote)
|
err = f.writers[q].WriteTo(out, remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
"error", err,
|
||||||
|
"udpAddr", remote,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else if hostinfo.remote.IsValid() {
|
} else if hostinfo.remote.IsValid() {
|
||||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
"error", err,
|
||||||
|
"udpAddr", remote,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Try to send via a relay
|
// Try to send via a relay
|
||||||
@@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.relayState.DeleteRelay(relayIP)
|
hostinfo.relayState.DeleteRelay(relayIP)
|
||||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
|
||||||
|
"relay", relayIP,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||||
// +build darwin dragonfly freebsd netbsd openbsd
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd
|
//go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd
|
||||||
// +build !darwin,!dragonfly,!freebsd,!netbsd,!openbsd
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
166
interface.go
166
interface.go
@@ -6,15 +6,15 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"sync"
|
||||||
"runtime"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -31,7 +31,7 @@ type InterfaceConfig struct {
|
|||||||
pki *PKI
|
pki *PKI
|
||||||
Cipher string
|
Cipher string
|
||||||
Firewall *Firewall
|
Firewall *Firewall
|
||||||
ServeDns bool
|
DnsServer *dnsServer
|
||||||
HandshakeManager *HandshakeManager
|
HandshakeManager *HandshakeManager
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
@@ -48,7 +48,7 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
@@ -59,7 +59,7 @@ type Interface struct {
|
|||||||
firewall *Firewall
|
firewall *Firewall
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
handshakeManager *HandshakeManager
|
handshakeManager *HandshakeManager
|
||||||
serveDns bool
|
dnsServer *dnsServer
|
||||||
createTime time.Time
|
createTime time.Time
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
myBroadcastAddrsTable *bart.Lite
|
myBroadcastAddrsTable *bart.Lite
|
||||||
@@ -87,14 +87,22 @@ type Interface struct {
|
|||||||
|
|
||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// fatalErr holds the first unexpected reader error that caused shutdown.
|
||||||
|
// nil means "no fatal error" (yet)
|
||||||
|
fatalErr atomic.Pointer[error]
|
||||||
|
// triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr
|
||||||
|
triggerShutdown func()
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -165,12 +173,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
|
ctx: ctx,
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
outside: c.Outside,
|
outside: c.Outside,
|
||||||
inside: c.Inside,
|
inside: c.Inside,
|
||||||
firewall: c.Firewall,
|
firewall: c.Firewall,
|
||||||
serveDns: c.ServeDns,
|
dnsServer: c.DnsServer,
|
||||||
handshakeManager: c.HandshakeManager,
|
handshakeManager: c.HandshakeManager,
|
||||||
createTime: time.Now(),
|
createTime: time.Now(),
|
||||||
lightHouse: c.lightHouse,
|
lightHouse: c.lightHouse,
|
||||||
@@ -211,19 +220,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() {
|
func (f *Interface) activate() error {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to get udp listen address")
|
f.l.Error("Failed to get udp listen address", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
|
f.l.Info("Nebula interface is active",
|
||||||
WithField("build", f.version).WithField("udpAddr", addr).
|
"interface", f.inside.Name(),
|
||||||
WithField("boringcrypto", boringEnabled()).
|
"networks", f.myVpnNetworks,
|
||||||
WithField("fips140", fips140.Enabled()).
|
"build", f.version,
|
||||||
Info("Nebula interface is active")
|
"udpAddr", addr,
|
||||||
|
"boringcrypto", boringEnabled(),
|
||||||
|
"fips140", fips140.Enabled(),
|
||||||
|
)
|
||||||
|
|
||||||
if f.routines > 1 {
|
if f.routines > 1 {
|
||||||
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
||||||
@@ -240,33 +252,58 @@ func (f *Interface) activate() {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.inside.Activate(); err != nil {
|
f.wg.Add(1) // for us to wait on Close() to return
|
||||||
|
if err = f.inside.Activate(); err != nil {
|
||||||
|
f.wg.Done()
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) run() {
|
func (f *Interface) run() (func() error, error) {
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenOut(i)
|
f.wg.Go(func() {
|
||||||
|
f.listenOut(i)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenIn(f.readers[i], i)
|
f.wg.Go(func() {
|
||||||
|
f.listenIn(f.readers[i], i)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() error {
|
||||||
|
f.wg.Wait()
|
||||||
|
if e := f.fatalErr.Load(); e != nil {
|
||||||
|
return *e
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one
|
||||||
|
func (f *Interface) onFatal(err error) {
|
||||||
|
swapped := f.fatalErr.CompareAndSwap(nil, &err)
|
||||||
|
if !swapped {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.triggerShutdown != nil {
|
||||||
|
f.triggerShutdown()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
@@ -274,42 +311,47 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
plaintext := make([]byte, udp.MTU)
|
plaintext := make([]byte, udp.MTU)
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil && !f.closed.Load() {
|
||||||
|
f.l.Error("Error while reading inbound packet, closing", "error", err)
|
||||||
|
f.onFatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.l.Debug("underlay reader is done", "reader", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
|
||||||
|
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
if !f.closed.Load() {
|
||||||
return
|
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
|
||||||
|
f.onFatal(err)
|
||||||
}
|
}
|
||||||
|
break
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.l.Debug("overlay reader is done", "reader", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
@@ -329,7 +371,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
|
|||||||
if initial || c.HasChanged("pki.disconnect_invalid") {
|
if initial || c.HasChanged("pki.disconnect_invalid") {
|
||||||
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
|
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
|
||||||
if !initial {
|
if !initial {
|
||||||
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
|
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -343,7 +385,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
|
|
||||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
f.l.Error("Error while creating firewall during reload", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -356,10 +398,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||||
// safe and just reset conntrack in this case.
|
// safe and just reset conntrack in this case.
|
||||||
if fw.rulesVersion == 0 {
|
if fw.rulesVersion == 0 {
|
||||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
|
||||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
"firewallHashes", fw.GetRuleHashes(),
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
"rulesVersion", fw.rulesVersion,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
fw.Conntrack = conntrack
|
fw.Conntrack = conntrack
|
||||||
}
|
}
|
||||||
@@ -367,10 +410,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
f.firewall = fw
|
f.firewall = fw
|
||||||
|
|
||||||
oldFw.Destroy()
|
oldFw.Destroy()
|
||||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
f.l.Info("New firewall has been installed",
|
||||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
"firewallHashes", fw.GetRuleHashes(),
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||||
Info("New firewall has been installed")
|
"rulesVersion", fw.rulesVersion,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadSendRecvError(c *config.C) {
|
func (f *Interface) reloadSendRecvError(c *config.C) {
|
||||||
@@ -392,8 +436,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
|
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
|
||||||
Info("Loaded send_recv_error config")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -416,8 +459,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
|
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
|
||||||
Info("Loaded accept_recv_error config")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -484,23 +526,23 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
|
var errs []error
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
for _, u := range f.writers {
|
// Release the udp readers
|
||||||
|
for i, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while closing udp socket")
|
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
|
||||||
}
|
errs = append(errs, err)
|
||||||
}
|
|
||||||
for i, r := range f.readers {
|
|
||||||
if i == 0 {
|
|
||||||
continue // f.readers[0] is f.inside, which we want to save for last
|
|
||||||
}
|
|
||||||
if err := r.Close(); err != nil {
|
|
||||||
f.l.WithError(err).Error("Error while closing tun reader")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun device (closing the tun also closes all readers)
|
||||||
return f.inside.Close()
|
closeErr := f.inside.Close()
|
||||||
|
if closeErr != nil {
|
||||||
|
errs = append(errs, closeErr)
|
||||||
|
}
|
||||||
|
f.wg.Done()
|
||||||
|
return errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|||||||
259
lighthouse.go
259
lighthouse.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -15,10 +16,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
@@ -69,18 +70,19 @@ type LightHouse struct {
|
|||||||
// Addr's of relays that can be used by peers to access me
|
// Addr's of relays that can be used by peers to access me
|
||||||
relaysForMe atomic.Pointer[[]netip.Addr]
|
relaysForMe atomic.Pointer[[]netip.Addr]
|
||||||
|
|
||||||
queryChan chan netip.Addr
|
updateTrigger chan struct{}
|
||||||
|
queryChan chan netip.Addr
|
||||||
|
|
||||||
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
|
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
|
||||||
|
|
||||||
metrics *MessageMetrics
|
metrics *MessageMetrics
|
||||||
metricHolepunchTx metrics.Counter
|
metricHolepunchTx metrics.Counter
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
||||||
// addrMap should be nil unless this is during a config reload
|
// addrMap should be nil unless this is during a config reload
|
||||||
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||||
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||||
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
||||||
if amLighthouse && nebulaPort == 0 {
|
if amLighthouse && nebulaPort == 0 {
|
||||||
@@ -105,6 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
nebulaPort: nebulaPort,
|
nebulaPort: nebulaPort,
|
||||||
punchConn: pc,
|
punchConn: pc,
|
||||||
punchy: p,
|
punchy: p,
|
||||||
|
updateTrigger: make(chan struct{}, 1),
|
||||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@@ -131,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
case *util.ContextualError:
|
case *util.ContextualError:
|
||||||
v.Log(l)
|
v.Log(l)
|
||||||
case error:
|
case error:
|
||||||
l.WithError(err).Error("failed to reload lighthouse")
|
l.Error("failed to reload lighthouse", "error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -203,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
||||||
addr := addrs[0].Unmap()
|
addr := addrs[0].Unmap()
|
||||||
if lh.myVpnNetworksTable.Contains(addr) {
|
if lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
|
||||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
"addr", rawAddr,
|
||||||
|
"entry", i+1,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
|
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
|
||||||
|
|
||||||
if !initial {
|
if !initial {
|
||||||
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
|
lh.l.Info("lighthouse.interval changed",
|
||||||
|
"interval", lh.interval.Load(),
|
||||||
|
)
|
||||||
|
|
||||||
if lh.updateCancel != nil {
|
if lh.updateCancel != nil {
|
||||||
// May not always have a running routine
|
// May not always have a running routine
|
||||||
@@ -316,6 +323,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
if !initial {
|
if !initial {
|
||||||
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
||||||
lh.l.Info("lighthouse.hosts has changed")
|
lh.l.Info("lighthouse.hosts has changed")
|
||||||
|
lh.TriggerUpdate()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
||||||
configRIP, err := netip.ParseAddr(v)
|
configRIP, err := netip.ParseAddr(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
|
lh.l.Warn("Parse relay from config failed",
|
||||||
|
"relay", v,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
lh.l.WithField("relay", v).Info("Read relay from config")
|
lh.l.Info("Read relay from config", "relay", v)
|
||||||
relaysForMe = append(relaysForMe, configRIP)
|
relaysForMe = append(relaysForMe, configRIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -360,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
|
||||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
"vpnAddr", addr,
|
||||||
|
"networks", lh.myVpnNetworks,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
out[i] = addr
|
out[i] = addr
|
||||||
}
|
}
|
||||||
@@ -432,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
|
||||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
"vpnAddr", vpnAddr,
|
||||||
|
"networks", lh.myVpnNetworks,
|
||||||
|
"entry", i+1,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
vals, ok := v.([]any)
|
vals, ok := v.([]any)
|
||||||
@@ -534,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
|||||||
lh.Lock()
|
lh.Lock()
|
||||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||||
if ok {
|
if ok {
|
||||||
|
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
|
||||||
for _, addr := range allVpnAddrs {
|
for _, addr := range allVpnAddrs {
|
||||||
srm := lh.addrMap[addr]
|
srm := lh.addrMap[addr]
|
||||||
if srm == rm {
|
if srm == rm {
|
||||||
delete(lh.addrMap, addr)
|
delete(lh.addrMap, addr)
|
||||||
if lh.l.Level >= logrus.DebugLevel {
|
if debugEnabled {
|
||||||
lh.l.Debugf("deleting %s from lighthouse.", addr)
|
lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -656,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
|||||||
|
|
||||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddrs", vpnAddrs,
|
||||||
|
"udpAddr", to,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
return false
|
return false
|
||||||
@@ -675,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
|||||||
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
|
||||||
udpAddr := protoV4AddrPortToNetAddrPort(to)
|
udpAddr := protoV4AddrPortToNetAddrPort(to)
|
||||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddr", vpnAddr,
|
||||||
|
"udpAddr", udpAddr,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -695,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
|
|||||||
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
|
||||||
udpAddr := protoV6AddrPortToNetAddrPort(to)
|
udpAddr := protoV6AddrPortToNetAddrPort(to)
|
||||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddr", vpnAddr,
|
||||||
|
"udpAddr", udpAddr,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -713,21 +739,14 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
|||||||
|
|
||||||
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
||||||
l := lh.GetLighthouses()
|
l := lh.GetLighthouses()
|
||||||
for i := range l {
|
return slices.Contains(l, vpnAddr)
|
||||||
if l[i] == vpnAddr {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
||||||
l := lh.GetLighthouses()
|
l := lh.GetLighthouses()
|
||||||
for i := range vpnAddrs {
|
for i := range vpnAddrs {
|
||||||
for j := range l {
|
if slices.Contains(l, vpnAddrs[i]) {
|
||||||
if l[j] == vpnAddrs[i] {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -779,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if !addr.Is4() {
|
if !addr.Is4() {
|
||||||
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
|
||||||
Error("Can't query lighthouse for v6 address using a v1 protocol")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -791,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
v1Query, err = msg.Marshal()
|
v1Query, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
lh.l.Error("Failed to marshal lighthouse v1 query payload",
|
||||||
WithField("lighthouseAddr", lhVpnAddr).
|
"error", err,
|
||||||
Error("Failed to marshal lighthouse v1 query payload")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -808,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
v2Query, err = msg.Marshal()
|
v2Query, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
lh.l.Error("Failed to marshal lighthouse v2 query payload",
|
||||||
WithField("lighthouseAddr", lhVpnAddr).
|
"error", err,
|
||||||
Error("Failed to marshal lighthouse v2 query payload")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -819,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
queried++
|
queried++
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
|
lh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "query",
|
||||||
|
"queryVpnAddr", addr,
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -848,11 +877,24 @@ func (lh *LightHouse) StartUpdateWorker() {
|
|||||||
return
|
return
|
||||||
case <-clockSource.C:
|
case <-clockSource.C:
|
||||||
continue
|
continue
|
||||||
|
case <-lh.updateTrigger:
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking
|
||||||
|
// operation intended to be called after a handshake completes with a lighthouse,
|
||||||
|
// so the lighthouse has our current addresses without waiting for the next
|
||||||
|
// periodic update.
|
||||||
|
func (lh *LightHouse) TriggerUpdate() {
|
||||||
|
select {
|
||||||
|
case lh.updateTrigger <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) SendUpdate() {
|
func (lh *LightHouse) SendUpdate() {
|
||||||
var v4 []*V4AddrPort
|
var v4 []*V4AddrPort
|
||||||
var v6 []*V6AddrPort
|
var v6 []*V6AddrPort
|
||||||
@@ -898,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if v1Update == nil {
|
if v1Update == nil {
|
||||||
if !lh.myVpnNetworks[0].Addr().Is4() {
|
if !lh.myVpnNetworks[0].Addr().Is4() {
|
||||||
lh.l.WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
|
||||||
Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var relays []uint32
|
var relays []uint32
|
||||||
@@ -923,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
v1Update, err = msg.Marshal()
|
v1Update, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Error while marshaling for lighthouse v1 update",
|
||||||
Error("Error while marshaling for lighthouse v1 update")
|
"error", err,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -950,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
v2Update, err = msg.Marshal()
|
v2Update, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Error while marshaling for lighthouse v2 update",
|
||||||
Error("Error while marshaling for lighthouse v2 update")
|
"error", err,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -960,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
updated++
|
updated++
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v)
|
lh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "update",
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -974,7 +1024,7 @@ type LightHouseHandler struct {
|
|||||||
out []byte
|
out []byte
|
||||||
pb []byte
|
pb []byte
|
||||||
meta *NebulaMeta
|
meta *NebulaMeta
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
||||||
@@ -1023,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := n.Unmarshal(p)
|
err := n.Unmarshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.Error("Failed to unmarshal lighthouse packet",
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"udpAddr", rAddr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.Error("Invalid lighthouse update",
|
||||||
Error("Invalid lighthouse update")
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"udpAddr", rAddr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1058,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
|
lhh.l.Debug("I don't answer queries, but received one", "from", addr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
lhh.l.Debug("Dropping malformed HostQuery",
|
||||||
Debugln("Dropping malformed HostQuery")
|
"from", fromVpnAddrs,
|
||||||
|
"details", n.Details,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
|
||||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"queryVpnAddr", queryVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1101,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
lhh.l.Error("Failed to marshal lighthouse host query reply",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1129,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newDest
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
lhh.l.Debug("unable to punch to host, no addresses in common",
|
||||||
|
"to", crt.Networks(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1156,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
|
lhh.l.Error("Failed to marshal lighthouse host was queried for",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1198,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
|||||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
lhh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "coalesceAnswers",
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1212,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
|
|
||||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
lhh.l.Error("dropping malformed HostQueryReply",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1238,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1262,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
|
|
||||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
lhh.l.Debug("Host sent invalid update",
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"answer", detailsVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1285,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
switch useVersion {
|
switch useVersion {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
if !fromVpnAddrs[0].Is4() {
|
if !fromVpnAddrs[0].Is4() {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrB := fromVpnAddrs[0].As4()
|
vpnAddrB := fromVpnAddrs[0].As4()
|
||||||
@@ -1293,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
case cert.Version2:
|
case cert.Version2:
|
||||||
// do nothing, we want to send a blank message
|
// do nothing, we want to send a blank message
|
||||||
default:
|
default:
|
||||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
lhh.l.Error("invalid protocol version", "useVersion", useVersion)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ln, err := n.MarshalTo(lhh.pb)
|
ln, err := n.MarshalTo(lhh.pb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
|
lhh.l.Error("Failed to marshal lighthouse host update ack",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1316,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
|
|
||||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
lhh.l.Debug("dropping invalid HostPunchNotification",
|
||||||
|
"details", n.Details,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1334,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
lhh.l.Debug("Punching",
|
||||||
|
"vpnPeer", vpnPeer,
|
||||||
|
"logVpnAddr", logVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1360,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
if lhh.lh.punchy.GetRespond() {
|
if lhh.lh.punchy.GetRespond() {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
lhh.l.Debug("Sending a nebula test packet",
|
||||||
|
"vpnAddr", detailsVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||||
|
|||||||
45
logger.go
45
logger.go
@@ -1,45 +0,0 @@
|
|||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func configLogger(l *logrus.Logger, c *config.C) error {
|
|
||||||
// set up our logging level
|
|
||||||
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
|
||||||
}
|
|
||||||
l.SetLevel(logLevel)
|
|
||||||
|
|
||||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
|
||||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
|
||||||
fullTimestamp := (timestampFormat != "")
|
|
||||||
if timestampFormat == "" {
|
|
||||||
timestampFormat = time.RFC3339
|
|
||||||
}
|
|
||||||
|
|
||||||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
|
||||||
switch logFormat {
|
|
||||||
case "text":
|
|
||||||
l.Formatter = &logrus.TextFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
FullTimestamp: fullTimestamp,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
case "json":
|
|
||||||
l.Formatter = &logrus.JSONFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
233
logging/logger.go
Normal file
233
logging/logger.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
// Package logging wires the nebula runtime-reconfigurable slog handler used
|
||||||
|
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
|
||||||
|
// NewLogger, then call ApplyConfig at startup and from a config reload
|
||||||
|
// callback to push logging.level, logging.format, and
|
||||||
|
// logging.disable_timestamp changes onto the logger without rebuilding it.
|
||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
|
||||||
|
// here keeps the logging package from depending on config directly, which
|
||||||
|
// would cycle through the shared test helpers (test.NewLogger imports
|
||||||
|
// logging, and config's tests import test). *config.C satisfies this
|
||||||
|
// interface structurally with no adapter.
|
||||||
|
type Config interface {
|
||||||
|
GetString(key, def string) string
|
||||||
|
GetBool(key string, def bool) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LevelTrace is a custom slog level below Debug, used when logging.level is
|
||||||
|
// "trace". slog has no builtin trace level; the value is one step below
|
||||||
|
// slog.LevelDebug in slog's 4-point spacing.
|
||||||
|
const LevelTrace = slog.Level(-8)
|
||||||
|
|
||||||
|
// NewLogger returns a *slog.Logger whose level, format, and timestamp
|
||||||
|
// emission can be reconfigured at runtime via ApplyConfig and the SSH debug
|
||||||
|
// commands. The default configuration is info-level text output so log
|
||||||
|
// calls made before ApplyConfig runs still produce output. Timestamps
|
||||||
|
// follow slog's default RFC3339Nano format; set logging.disable_timestamp
|
||||||
|
// in config to suppress them.
|
||||||
|
//
|
||||||
|
// ApplyConfig and the SSH commands discover the reconfig surface via
|
||||||
|
// structural type-assertion on l.Handler(), so replacement implementations
|
||||||
|
// (tests, platform-specific sinks) need only implement the subset of
|
||||||
|
// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)}
|
||||||
|
// they care about. Callers that pass a plain *slog.Logger without these
|
||||||
|
// methods get a silent no-op; reconfiguration is always opt-in.
|
||||||
|
func NewLogger(w io.Writer) *slog.Logger {
|
||||||
|
return slog.New(NewHandler(w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler builds the *Handler that NewLogger wraps. Exported for
|
||||||
|
// platform-specific sinks (notably cmd/nebula-service/logs_windows.go)
|
||||||
|
// that want to wrap the handler with extra behavior, such as tagging each
|
||||||
|
// record with its Event Log severity, while still benefiting from all the
|
||||||
|
// level / format / timestamp / WithAttrs machinery implemented here.
|
||||||
|
func NewHandler(w io.Writer) *Handler {
|
||||||
|
root := &handlerRoot{}
|
||||||
|
root.level.Set(slog.LevelInfo)
|
||||||
|
opts := &slog.HandlerOptions{Level: &root.level}
|
||||||
|
return &Handler{
|
||||||
|
root: root,
|
||||||
|
text: slog.NewTextHandler(w, opts),
|
||||||
|
json: slog.NewJSONHandler(w, opts),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlerRoot carries the reconfiguration state shared by every logger
|
||||||
|
// derived from a NewHandler call. All fields are consulted on the log
|
||||||
|
// path and updated lock-free.
|
||||||
|
type handlerRoot struct {
|
||||||
|
level slog.LevelVar
|
||||||
|
disableTimestamp atomic.Bool
|
||||||
|
// jsonMode picks which of the pre-derived inner handlers Handler.Handle
|
||||||
|
// dispatches to. Flipping it propagates instantly to every derived logger
|
||||||
|
// without rebuilding or chain-replaying anything.
|
||||||
|
jsonMode atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler is the slog.Handler returned by NewHandler. It holds two
|
||||||
|
// pre-derived slog handlers -- one text, one json -- both built from the
|
||||||
|
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
|
||||||
|
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
|
||||||
|
// effect immediately across the whole process without having to rebuild
|
||||||
|
// any derived loggers.
|
||||||
|
type Handler struct {
|
||||||
|
root *handlerRoot
|
||||||
|
text slog.Handler
|
||||||
|
json slog.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
|
||||||
|
return h.root.level.Level() <= l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
|
||||||
|
if h.root.disableTimestamp.Load() {
|
||||||
|
r.Time = time.Time{}
|
||||||
|
}
|
||||||
|
if h.root.jsonMode.Load() {
|
||||||
|
return h.json.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
return h.text.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
root: h.root,
|
||||||
|
text: h.text.WithAttrs(attrs),
|
||||||
|
json: h.json.WithAttrs(attrs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) WithGroup(name string) slog.Handler {
|
||||||
|
if name == "" {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
root: h.root,
|
||||||
|
text: h.text.WithGroup(name),
|
||||||
|
json: h.json.WithGroup(name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLevel updates the effective log level. Propagates to every derived
|
||||||
|
// logger via the shared LevelVar.
|
||||||
|
func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) }
|
||||||
|
|
||||||
|
// GetLevel reports the current log level.
|
||||||
|
func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() }
|
||||||
|
|
||||||
|
// SetFormat flips the output format atomically. Valid formats are "text"
|
||||||
|
// and "json". Every derived logger sees the new format on its next Handle
|
||||||
|
// call; no rebuild or registration is required.
|
||||||
|
func (h *Handler) SetFormat(format string) error {
|
||||||
|
switch format {
|
||||||
|
case "text":
|
||||||
|
h.root.jsonMode.Store(false)
|
||||||
|
case "json":
|
||||||
|
h.root.jsonMode.Store(true)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFormat reports the currently selected format name.
|
||||||
|
func (h *Handler) GetFormat() string {
|
||||||
|
if h.root.jsonMode.Load() {
|
||||||
|
return "json"
|
||||||
|
}
|
||||||
|
return "text"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableTimestamp toggles whether Handle zeroes r.Time before
|
||||||
|
// dispatching (slog's builtin text/json handlers skip emitting the time
|
||||||
|
// attribute on a zero time).
|
||||||
|
func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) }
|
||||||
|
|
||||||
|
// ApplyConfig reads logging.level, logging.format, and (optionally)
|
||||||
|
// logging.disable_timestamp from c and applies them to l. The reconfig
|
||||||
|
// surface is discovered via structural type-assertion on l.Handler(), so
|
||||||
|
// foreign handlers silently opt out of whichever capabilities they do not
|
||||||
|
// implement.
|
||||||
|
//
|
||||||
|
// nebula.Main does NOT call this function on your behalf; callers that want
|
||||||
|
// config-driven log level / format / timestamp updates invoke it at
|
||||||
|
// startup and register it as a reload callback themselves. This keeps the
|
||||||
|
// library from mutating an embedder's logger without their say-so.
|
||||||
|
func ApplyConfig(l *slog.Logger, c Config) error {
|
||||||
|
h := l.Handler()
|
||||||
|
|
||||||
|
lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok {
|
||||||
|
ls.SetLevel(lvl)
|
||||||
|
}
|
||||||
|
|
||||||
|
format := strings.ToLower(c.GetString("logging.format", "text"))
|
||||||
|
if fs, ok := h.(interface{ SetFormat(string) error }); ok {
|
||||||
|
if err := fs.SetFormat(format); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok {
|
||||||
|
ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseLevel converts a config-string level name ("trace", "debug", "info",
|
||||||
|
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and
|
||||||
|
// "panic" are accepted for backwards compatibility with pre-slog configs
|
||||||
|
// and both map to slog.LevelError.
|
||||||
|
func ParseLevel(s string) (slog.Level, error) {
|
||||||
|
switch s {
|
||||||
|
case "trace":
|
||||||
|
return LevelTrace, nil
|
||||||
|
case "debug":
|
||||||
|
return slog.LevelDebug, nil
|
||||||
|
case "info":
|
||||||
|
return slog.LevelInfo, nil
|
||||||
|
case "warn", "warning":
|
||||||
|
return slog.LevelWarn, nil
|
||||||
|
case "error":
|
||||||
|
return slog.LevelError, nil
|
||||||
|
case "fatal", "panic":
|
||||||
|
return slog.LevelError, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("not a valid logging level: %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LevelName returns a human-readable name for a slog.Level matching the
|
||||||
|
// strings accepted by ParseLevel.
|
||||||
|
func LevelName(l slog.Level) string {
|
||||||
|
switch {
|
||||||
|
case l <= LevelTrace:
|
||||||
|
return "trace"
|
||||||
|
case l <= slog.LevelDebug:
|
||||||
|
return "debug"
|
||||||
|
case l <= slog.LevelInfo:
|
||||||
|
return "info"
|
||||||
|
case l <= slog.LevelWarn:
|
||||||
|
return "warn"
|
||||||
|
default:
|
||||||
|
return "error"
|
||||||
|
}
|
||||||
|
}
|
||||||
90
logging/logger_bench_test.go
Normal file
90
logging/logger_bench_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkLogger_* compare the handler returned by NewLogger against a
|
||||||
|
// stock slog text handler. The key thing we care about is the per-log
|
||||||
|
// cost on a logger that has been derived via .With(), because that is the
|
||||||
|
// shape subsystems store on their structs (HostInfo.logger(),
|
||||||
|
// lh.l.With("subsystem", ...), etc.) and call from hot paths.
|
||||||
|
|
||||||
|
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gated-off-path benchmarks: mimic the typical hot-path shape
|
||||||
|
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below
|
||||||
|
// the active level. This is the dominant pattern in inside.go/outside.go and
|
||||||
|
// what we pay on every packet.
|
||||||
|
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
ctx := context.Background()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if l.Enabled(ctx, slog.LevelDebug) {
|
||||||
|
l.Debug("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
ctx := context.Background()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if l.Enabled(ctx, slog.LevelDebug) {
|
||||||
|
l.Debug("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
83
main.go
83
main.go
@@ -3,13 +3,13 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|
||||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
buildVersion = moduleVersion()
|
buildVersion = moduleVersion()
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logger
|
|
||||||
l.Formatter = &logrus.TextFormatter{
|
|
||||||
FullTimestamp: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print the config if in test, the exit comes later
|
// Print the config if in test, the exit comes later
|
||||||
if configTest {
|
if configTest {
|
||||||
b, err := yaml.Marshal(c.Settings)
|
b, err := yaml.Marshal(c.Settings)
|
||||||
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Print the final config
|
// Print the final config
|
||||||
l.Println(string(b))
|
l.Info(string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := configLogger(l, c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
err := configLogger(l, c)
|
|
||||||
if err != nil {
|
|
||||||
l.WithError(err).Error("Failed to configure the logger")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
pki, err := NewPKIFromConfig(l, c)
|
pki, err := NewPKIFromConfig(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||||
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||||
}
|
}
|
||||||
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||||
}
|
}
|
||||||
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
|
||||||
sshStart = nil
|
sshStart = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -99,19 +82,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
routines = 1
|
routines = 1
|
||||||
}
|
}
|
||||||
if routines > 1 {
|
if routines > 1 {
|
||||||
l.WithField("routines", routines).Info("Using multiple routines")
|
l.Info("Using multiple routines", "routines", routines)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// deprecated and undocumented
|
// deprecated and undocumented
|
||||||
tunQueues := c.GetInt("tun.routines", 1)
|
tunQueues := c.GetInt("tun.routines", 1)
|
||||||
udpQueues := c.GetInt("listen.routines", 1)
|
udpQueues := c.GetInt("listen.routines", 1)
|
||||||
if tunQueues > udpQueues {
|
routines = max(tunQueues, udpQueues)
|
||||||
routines = tunQueues
|
|
||||||
} else {
|
|
||||||
routines = udpQueues
|
|
||||||
}
|
|
||||||
if routines != 1 {
|
if routines != 1 {
|
||||||
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
|
l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
conntrackCacheTimeout = 1 * time.Second
|
conntrackCacheTimeout = 1 * time.Second
|
||||||
}
|
}
|
||||||
if conntrackCacheTimeout > 0 {
|
if conntrackCacheTimeout > 0 {
|
||||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tun overlay.Device
|
var tun overlay.Device
|
||||||
@@ -170,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
@@ -219,13 +198,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||||
|
|
||||||
serveDns := false
|
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
|
||||||
if c.GetBool("lighthouse.serve_dns", false) {
|
if err != nil {
|
||||||
if c.GetBool("lighthouse.am_lighthouse", false) {
|
l.Warn("Failed to start DNS responder", "error", err)
|
||||||
serveDns = true
|
|
||||||
} else {
|
|
||||||
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
@@ -234,7 +209,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
Outside: udpConns[0],
|
Outside: udpConns[0],
|
||||||
pki: pki,
|
pki: pki,
|
||||||
Firewall: fw,
|
Firewall: fw,
|
||||||
ServeDns: serveDns,
|
DnsServer: ds,
|
||||||
HandshakeManager: handshakeManager,
|
HandshakeManager: handshakeManager,
|
||||||
connectionManager: connManager,
|
connectionManager: connManager,
|
||||||
lightHouse: lightHouse,
|
lightHouse: lightHouse,
|
||||||
@@ -271,7 +246,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
go handshakeManager.Run(ctx)
|
go handshakeManager.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
statsStart, err := startStats(l, c, buildVersion, configTest)
|
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
|
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
|
||||||
}
|
}
|
||||||
@@ -284,23 +259,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
attachCommands(l, c, ssh, ifce)
|
attachCommands(l, c, ssh, ifce)
|
||||||
|
|
||||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
|
||||||
var dnsStart func()
|
|
||||||
if lightHouse.amLighthouse && serveDns {
|
|
||||||
l.Debugln("Starting dns server")
|
|
||||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
ifce,
|
state: StateReady,
|
||||||
l,
|
f: ifce,
|
||||||
ctx,
|
l: l,
|
||||||
cancel,
|
ctx: ctx,
|
||||||
sshStart,
|
cancel: cancel,
|
||||||
statsStart,
|
sshStart: sshStart,
|
||||||
dnsStart,
|
statsStart: stats.Start,
|
||||||
lightHouse.StartUpdateWorker,
|
dnsStart: ds.Start,
|
||||||
connManager.Start,
|
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||||
|
connectionManagerStart: connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
60
noise.go
60
noise.go
@@ -15,14 +15,12 @@ type endianness interface {
|
|||||||
var noiseEndianness endianness = binary.BigEndian
|
var noiseEndianness endianness = binary.BigEndian
|
||||||
|
|
||||||
type NebulaCipherState struct {
|
type NebulaCipherState struct {
|
||||||
c noise.Cipher
|
c cipher.AEAD
|
||||||
//k [32]byte
|
|
||||||
//n uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
|
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
|
||||||
return &NebulaCipherState{c: s.Cipher()}
|
x := s.Cipher()
|
||||||
|
return &NebulaCipherState{c: x.(cipher.AEAD)}
|
||||||
}
|
}
|
||||||
|
|
||||||
type cipherAEADDanger interface {
|
type cipherAEADDanger interface {
|
||||||
@@ -40,25 +38,20 @@ type cipherAEADDanger interface {
|
|||||||
// be re-used by callers to minimize garbage collection.
|
// be re-used by callers to minimize garbage collection.
|
||||||
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
|
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||||
if s != nil {
|
if s != nil {
|
||||||
switch ce := s.c.(type) {
|
// TODO: Is this okay now that we have made messageCounter atomic?
|
||||||
case cipherAEADDanger:
|
// Alternative may be to split the counter space into ranges
|
||||||
return ce.EncryptDanger(out, ad, plaintext, n, nb)
|
//if n <= s.n {
|
||||||
default:
|
// return nil, errors.New("CRITICAL: a duplicate counter value was used")
|
||||||
// TODO: Is this okay now that we have made messageCounter atomic?
|
//}
|
||||||
// Alternative may be to split the counter space into ranges
|
//s.n = n
|
||||||
//if n <= s.n {
|
nb[0] = 0
|
||||||
// return nil, errors.New("CRITICAL: a duplicate counter value was used")
|
nb[1] = 0
|
||||||
//}
|
nb[2] = 0
|
||||||
//s.n = n
|
nb[3] = 0
|
||||||
nb[0] = 0
|
noiseEndianness.PutUint64(nb[4:], n)
|
||||||
nb[1] = 0
|
out = s.c.Seal(out, nb, plaintext, ad)
|
||||||
nb[2] = 0
|
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
|
||||||
nb[3] = 0
|
return out, nil
|
||||||
noiseEndianness.PutUint64(nb[4:], n)
|
|
||||||
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
|
|
||||||
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.New("no cipher state available to encrypt")
|
return nil, errors.New("no cipher state available to encrypt")
|
||||||
}
|
}
|
||||||
@@ -66,17 +59,12 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
|
|||||||
|
|
||||||
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
|
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||||
if s != nil {
|
if s != nil {
|
||||||
switch ce := s.c.(type) {
|
nb[0] = 0
|
||||||
case cipherAEADDanger:
|
nb[1] = 0
|
||||||
return ce.DecryptDanger(out, ad, ciphertext, n, nb)
|
nb[2] = 0
|
||||||
default:
|
nb[3] = 0
|
||||||
nb[0] = 0
|
noiseEndianness.PutUint64(nb[4:], n)
|
||||||
nb[1] = 0
|
return s.c.Open(out, nb, ciphertext, ad)
|
||||||
nb[2] = 0
|
|
||||||
nb[3] = 0
|
|
||||||
noiseEndianness.PutUint64(nb[4:], n)
|
|
||||||
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return []byte{}, nil
|
return []byte{}, nil
|
||||||
}
|
}
|
||||||
@@ -84,7 +72,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
|
|||||||
|
|
||||||
func (s *NebulaCipherState) Overhead() int {
|
func (s *NebulaCipherState) Overhead() int {
|
||||||
if s != nil {
|
if s != nil {
|
||||||
return s.c.(cipher.AEAD).Overhead()
|
return s.c.Overhead()
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build !boringcrypto
|
//go:build !boringcrypto
|
||||||
// +build !boringcrypto
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
216
outside.go
216
outside.go
@@ -1,15 +1,16 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
if len(packet) > 1 {
|
if len(packet) > 1 {
|
||||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
|
f.l.Info("Error while parsing inbound packet",
|
||||||
|
"from", via,
|
||||||
|
"error", err,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
|
f.l.Debug("Refusing to process double encrypted packet", "from", via)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
if !ok {
|
if !ok {
|
||||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||||
// its internal mapping. This should never happen.
|
// its internal mapping. This should never happen.
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
|
||||||
|
"relayTo", relay.PeerAddr,
|
||||||
|
"error", err,
|
||||||
|
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
hostinfo.logger(f.l).Info("Unexpected target relay state",
|
||||||
|
"relayTo", relay.PeerAddr,
|
||||||
|
"relayFrom", hostinfo.vpnAddrs[0],
|
||||||
|
"targetRelayState", targetRelay.State,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
|
||||||
WithField("packet", packet).
|
"error", err,
|
||||||
Error("Failed to decrypt lighthouse packet")
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
|
||||||
WithField("packet", packet).
|
"error", err,
|
||||||
Error("Failed to decrypt test packet")
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,9 +210,17 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
if !f.handleEncrypted(ci, via, h) {
|
if !f.handleEncrypted(ci, via, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
|
||||||
|
"error", err,
|
||||||
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("from", via).
|
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
|
||||||
Info("Close tunnel received, tearing down.")
|
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
f.closeTunnel(hostinfo)
|
||||||
return
|
return
|
||||||
@@ -204,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
|
||||||
WithField("packet", packet).
|
"error", err,
|
||||||
Error("Failed to decrypt Control packet")
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
|||||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
|
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
|
||||||
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
|
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||||
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
return
|
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
|
||||||
}
|
|
||||||
|
|
||||||
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
|
||||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||||
Info("Host roamed to new udp ip/port.")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
|
||||||
|
"suppressSeconds", RoamingSuppressSeconds,
|
||||||
|
"udpAddr", hostinfo.remote,
|
||||||
|
"newAddr", via.UdpAddr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
|
||||||
|
"udpAddr", hostinfo.remote,
|
||||||
|
"newAddr", via.UdpAddr,
|
||||||
|
)
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
hostinfo.lastRoamRemote = hostinfo.remote
|
hostinfo.lastRoamRemote = hostinfo.remote
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
hostinfo.SetRemote(via.UdpAddr)
|
||||||
@@ -327,13 +366,29 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
proto := layers.IPProtocol(data[protoAt])
|
proto := layers.IPProtocol(data[protoAt])
|
||||||
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
case layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
fp.RemotePort = 0
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = 0
|
fp.LocalPort = 0
|
||||||
fp.Fragment = false
|
fp.Fragment = false
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
case layers.IPProtocolICMPv6:
|
||||||
|
if dataLen < offset+6 {
|
||||||
|
return ErrIPv6PacketTooShort
|
||||||
|
}
|
||||||
|
fp.Protocol = uint8(proto)
|
||||||
|
fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6
|
||||||
|
icmptype := data[offset+1]
|
||||||
|
switch icmptype {
|
||||||
|
case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply:
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier
|
||||||
|
default:
|
||||||
|
fp.RemotePort = 0
|
||||||
|
}
|
||||||
|
fp.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
||||||
if dataLen < offset+4 {
|
if dataLen < offset+4 {
|
||||||
return ErrIPv6PacketTooShort
|
return ErrIPv6PacketTooShort
|
||||||
@@ -423,34 +478,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
|
|
||||||
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
||||||
minLen := ihl
|
minLen := ihl
|
||||||
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
|
if !fp.Fragment {
|
||||||
minLen += minFwPacketLen
|
if fp.Protocol == firewall.ProtoICMP {
|
||||||
|
minLen += minFwPacketLen + 2
|
||||||
|
} else {
|
||||||
|
minLen += minFwPacketLen
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(data) < minLen {
|
if len(data) < minLen {
|
||||||
return ErrIPv4InvalidHeaderLength
|
return ErrIPv4InvalidHeaderLength
|
||||||
}
|
}
|
||||||
|
|
||||||
// Firewall packets are locally oriented
|
if incoming { // Firewall packets are locally oriented
|
||||||
if incoming {
|
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
|
||||||
} else {
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
}
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
if fp.Fragment {
|
||||||
} else {
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
fp.LocalPort = 0
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
|
||||||
}
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
|
||||||
|
fp.LocalPort = 0 //code would be uint16(data[ihl+1])
|
||||||
|
} else if incoming {
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
|
||||||
|
} else {
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -464,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
||||||
hostinfo.logger(f.l).WithField("header", h).
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
Debugln("dropping out of window packet")
|
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
|
||||||
|
}
|
||||||
return nil, errors.New("out of window packet")
|
return nil, errors.New("out of window packet")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -477,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(out, true, fwPacket)
|
err = newPacket(out, true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
||||||
Warnf("Error while validating inbound packet")
|
"error", err,
|
||||||
|
"packet", out,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
Debugln("dropping out of window packet")
|
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -499,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).Debug("dropping inbound packet",
|
||||||
WithField("reason", dropReason).
|
"fwPacket", fwPacket,
|
||||||
Debugln("dropping inbound packet")
|
"reason", dropReason,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -510,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
_, err = f.readers[q].Write(out)
|
_, err = f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.Error("Failed to write to tun", "error", err)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -526,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
|||||||
|
|
||||||
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||||
_ = f.outside.WriteTo(b, endpoint)
|
_ = f.outside.WriteTo(b, endpoint)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("index", index).
|
f.l.Debug("Recv error sent",
|
||||||
WithField("udpAddr", endpoint).
|
"index", index,
|
||||||
Debug("Recv error sent")
|
"udpAddr", endpoint,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||||
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
|
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
|
||||||
f.l.WithField("index", h.RemoteIndex).
|
f.l.Debug("Recv error received, ignoring",
|
||||||
WithField("udpAddr", addr).
|
"index", h.RemoteIndex,
|
||||||
Debug("Recv error received, ignoring")
|
"udpAddr", addr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("index", h.RemoteIndex).
|
f.l.Debug("Recv error received",
|
||||||
WithField("udpAddr", addr).
|
"index", h.RemoteIndex,
|
||||||
Debug("Recv error received")
|
"udpAddr", addr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
|
f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
f.l.Info("Someone spoofing recv_errors?",
|
||||||
|
"addr", addr,
|
||||||
|
"hostinfoRemote", hostinfo.remote,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
// next layer, missing length byte
|
// next layer, missing length byte
|
||||||
err = newPacket(buffer.Bytes()[:49], true, p)
|
err = newPacket(buffer.Bytes()[:49], true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
err = nil
|
||||||
|
|
||||||
// A good ICMP packet
|
// A good ICMP packet
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
DstIP: net.IPv6linklocalallnodes,
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
}
|
}
|
||||||
|
|
||||||
icmp := layers.ICMPv6{}
|
icmp := layers.ICMPv6{
|
||||||
|
TypeCode: layers.ICMPv6TypeEchoRequest,
|
||||||
buffer.Clear()
|
Checksum: 0x1234,
|
||||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
buffer.Clear()
|
||||||
require.NoError(t, err)
|
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp))
|
||||||
|
require.Error(t, newPacket(buffer.Bytes(), true, p))
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
echo := layers.ICMPv6Echo{
|
||||||
|
Identifier: 0xabcd,
|
||||||
|
SeqNumber: 1234,
|
||||||
|
}
|
||||||
|
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo))
|
||||||
|
require.NoError(t, newPacket(buffer.Bytes(), true, p))
|
||||||
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
assert.Equal(t, uint16(0), p.RemotePort)
|
assert.Equal(t, uint16(0xabcd), p.RemotePort)
|
||||||
assert.Equal(t, uint16(0), p.LocalPort)
|
assert.Equal(t, uint16(0), p.LocalPort)
|
||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
@@ -574,7 +581,7 @@ func BenchmarkParseV6(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
evilBytes := buffer.Bytes()
|
evilBytes := buffer.Bytes()
|
||||||
for i := 0; i < 200; i++ {
|
for range 200 {
|
||||||
evilBytes = append(evilBytes, hopHeader...)
|
evilBytes = append(evilBytes, hopHeader...)
|
||||||
}
|
}
|
||||||
evilBytes = append(evilBytes, lastHopHeader...)
|
evilBytes = append(evilBytes, lastHopHeader...)
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
package test
|
// Package overlaytest provides fakes of overlay.Device for tests that do
|
||||||
|
// not want to touch a real tun device or route table.
|
||||||
|
package overlaytest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@@ -8,6 +10,9 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NoopTun is an overlay.Device that silently discards every read and write.
|
||||||
|
// Useful in tests that need to construct a nebula Interface but do not
|
||||||
|
// exercise the datapath.
|
||||||
type NoopTun struct{}
|
type NoopTun struct{}
|
||||||
|
|
||||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||||
@@ -2,6 +2,7 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -9,7 +10,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -48,11 +48,14 @@ func (r Route) String() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
||||||
routeTree := new(bart.Table[routing.Gateways])
|
routeTree := new(bart.Table[routing.Gateways])
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !allowMTU && r.MTU > 0 {
|
if !allowMTU && r.MTU > 0 {
|
||||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
l.Warn("route MTU is not supported on this platform",
|
||||||
|
"goos", runtime.GOOS,
|
||||||
|
"route", r,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
gateways := r.Via
|
gateways := r.Via
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("1.0.0.2")
|
ip, err := netip.ParseAddr("1.0.0.2")
|
||||||
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
|
|||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 3)
|
assert.Len(t, routes, 3)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("192.168.86.1")
|
ip, err := netip.ParseAddr("192.168.86.1")
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
@@ -22,9 +22,9 @@ func (e *NameError) Error() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|
||||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
switch {
|
switch {
|
||||||
case c.GetBool("tun.disabled", false):
|
case c.GetBool("tun.disabled", false):
|
||||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||||
@@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -23,10 +23,10 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTun not supported in Android")
|
return nil, fmt.Errorf("newTun not supported in Android")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -14,7 +15,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -30,7 +30,7 @@ type tun struct {
|
|||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
linkAddr *netroute.LinkAddr
|
linkAddr *netroute.LinkAddr
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
out []byte
|
out []byte
|
||||||
@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
|
|||||||
Lifetime addrLifetime
|
Lifetime addrLifetime
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
name := c.GetString("tun.dev", "")
|
name := c.GetString("tun.dev", "")
|
||||||
ifIndex := -1
|
ifIndex := -1
|
||||||
if name != "" && name != "utun" {
|
if name != "" && name != "utun" {
|
||||||
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
err := addRoute(r.Cidr, t.linkAddr)
|
err := addRoute(r.Cidr, t.linkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, unix.EEXIST) {
|
if errors.Is(err, unix.EEXIST) {
|
||||||
t.l.WithField("route", r.Cidr).
|
t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr)
|
||||||
Warnf("unable to add unsafe_route, identical route already exists")
|
|
||||||
} else {
|
} else {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
@@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -19,10 +20,10 @@ type disabledTun struct {
|
|||||||
// Track these metrics since we don't have the tun device to do it for us
|
// Track these metrics since we don't have the tun device to do it for us
|
||||||
tx metrics.Counter
|
tx metrics.Counter
|
||||||
rx metrics.Counter
|
rx metrics.Counter
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
|
||||||
tun := &disabledTun{
|
tun := &disabledTun{
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
read: make(chan []byte, queueLen),
|
read: make(chan []byte, queueLen),
|
||||||
@@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.tx.Inc(1)
|
t.tx.Inc(1)
|
||||||
if t.l.Level >= logrus.DebugLevel {
|
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
t.l.Debug("Write payload", "raw", prettyPacket(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
return copy(b, r), nil
|
return copy(b, r), nil
|
||||||
@@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
|||||||
select {
|
select {
|
||||||
case t.read <- out:
|
case t.read <- out:
|
||||||
default:
|
default:
|
||||||
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
t.l.Debug("tun_disabled: dropped ICMP Echo Reply response")
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
|
|
||||||
// Check for ICMP Echo Request before spending time doing the full parsing
|
// Check for ICMP Echo Request before spending time doing the full parsing
|
||||||
if t.handleICMPEchoRequest(b) {
|
if t.handleICMPEchoRequest(b) {
|
||||||
if t.l.Level >= logrus.DebugLevel {
|
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b))
|
||||||
}
|
}
|
||||||
} else if t.l.Level >= logrus.DebugLevel {
|
} else if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b))
|
||||||
}
|
}
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|||||||
120
overlay/tun_file_linux_test.go
Normal file
120
overlay/tun_file_linux_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
//go:build linux && !android && !e2e_testing
|
||||||
|
// +build linux,!android,!e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
||||||
|
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
|
||||||
|
func newReadPipe(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
var fds [2]int
|
||||||
|
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
|
||||||
|
t.Fatalf("pipe2: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = unix.Close(fds[1]) })
|
||||||
|
return fds[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
|
||||||
|
tf, err := newTunFd(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newTunFd: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = tf.Close() })
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := tf.Read(make([]byte, 64))
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Verify Read is actually blocked in poll.
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
t.Fatalf("Read returned before shutdown signal: %v", err)
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tf.wakeForShutdown(); err != nil {
|
||||||
|
t.Fatalf("wakeForShutdown: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if !errors.Is(err, os.ErrClosed) {
|
||||||
|
t.Fatalf("expected os.ErrClosed, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Read did not wake on shutdown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||||
|
parent, err := newTunFd(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newTunFd: %v", err)
|
||||||
|
}
|
||||||
|
friend, err := parent.newFriend(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
_ = parent.Close()
|
||||||
|
t.Fatalf("newFriend: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = friend.Close()
|
||||||
|
_ = parent.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
readers := []*tunFile{parent, friend}
|
||||||
|
errs := make([]error, len(readers))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i, r := range readers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int, r *tunFile) {
|
||||||
|
defer wg.Done()
|
||||||
|
_, errs[i] = r.Read(make([]byte, 64))
|
||||||
|
}(i, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
if err := parent.wakeForShutdown(); err != nil {
|
||||||
|
t.Fatalf("wakeForShutdown: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { wg.Wait(); close(done) }()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("readers did not wake")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, err := range errs {
|
||||||
|
if !errors.Is(err, os.ErrClosed) {
|
||||||
|
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunFile_Close_Idempotent(t *testing.T) {
|
||||||
|
tf, err := newTunFd(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newTunFd: %v", err)
|
||||||
|
}
|
||||||
|
if err := tf.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close: %v", err)
|
||||||
|
}
|
||||||
|
if err := tf.Close(); err != nil {
|
||||||
|
t.Fatalf("second Close should be a no-op, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,15 +9,18 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
@@ -92,133 +95,232 @@ type tun struct {
|
|||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
linkAddr *netroute.LinkAddr
|
linkAddr *netroute.LinkAddr
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
devFd int
|
|
||||||
|
fd int
|
||||||
|
shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls
|
||||||
|
shutdownW int // write end of the shutdown pipe; closing this signals shutdown to any blocked reader/writer
|
||||||
|
readPoll [2]unix.PollFd
|
||||||
|
writePoll [2]unix.PollFd
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// blockOnRead waits until the tun fd is readable or shutdown has been signaled.
|
||||||
|
// Returns os.ErrClosed if Close was called.
|
||||||
|
func (t *tun) blockOnRead() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(t.readPoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tunEvents := t.readPoll[0].Revents
|
||||||
|
shutdownEvents := t.readPoll[1].Revents
|
||||||
|
t.readPoll[0].Revents = 0
|
||||||
|
t.readPoll[1].Revents = 0
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) blockOnWrite() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(t.writePoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tunEvents := t.writePoll[0].Revents
|
||||||
|
shutdownEvents := t.writePoll[1].Revents
|
||||||
|
t.writePoll[0].Revents = 0
|
||||||
|
t.writePoll[1].Revents = 0
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
// first 4 bytes is protocol family, in network byte order
|
||||||
head := make([]byte, 4)
|
var head [4]byte
|
||||||
|
iovecs := [2]syscall.Iovec{
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
{&head[0], 4},
|
||||||
{&to[0], uint64(len(to))},
|
{&to[0], uint64(len(to))},
|
||||||
}
|
}
|
||||||
|
for {
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
|
||||||
|
if errno == 0 {
|
||||||
var err error
|
bytesRead := int(n)
|
||||||
if errno != 0 {
|
if bytesRead < 4 {
|
||||||
err = syscall.Errno(errno)
|
return 0, nil
|
||||||
} else {
|
}
|
||||||
err = nil
|
return bytesRead - 4, nil
|
||||||
}
|
}
|
||||||
// fix bytes read number to exclude header
|
switch errno {
|
||||||
bytesRead := int(n)
|
case unix.EAGAIN:
|
||||||
if bytesRead < 0 {
|
if err := t.blockOnRead(); err != nil {
|
||||||
return bytesRead, err
|
return 0, err
|
||||||
} else if bytesRead < 4 {
|
}
|
||||||
return 0, err
|
case unix.EINTR:
|
||||||
} else {
|
// retry
|
||||||
return bytesRead - 4, err
|
case unix.EBADF:
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
default:
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
// Write is only valid for single threaded use
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(from) <= 1 {
|
if len(from) <= 1 {
|
||||||
return 0, syscall.EIO
|
return 0, syscall.EIO
|
||||||
}
|
}
|
||||||
|
|
||||||
ipVer := from[0] >> 4
|
ipVer := from[0] >> 4
|
||||||
var head []byte
|
var head [4]byte
|
||||||
// first 4 bytes is protocol family, in network byte order
|
// first 4 bytes is protocol family, in network byte order
|
||||||
if ipVer == 4 {
|
switch ipVer {
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
case 4:
|
||||||
} else if ipVer == 6 {
|
head[3] = syscall.AF_INET
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
case 6:
|
||||||
} else {
|
head[3] = syscall.AF_INET6
|
||||||
|
default:
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
}
|
}
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
|
iovecs := [2]syscall.Iovec{
|
||||||
{&head[0], 4},
|
{&head[0], 4},
|
||||||
{&from[0], uint64(len(from))},
|
{&from[0], uint64(len(from))},
|
||||||
}
|
}
|
||||||
|
for {
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
|
||||||
|
if errno == 0 {
|
||||||
var err error
|
return int(n) - 4, nil
|
||||||
if errno != 0 {
|
}
|
||||||
err = syscall.Errno(errno)
|
switch errno {
|
||||||
} else {
|
case unix.EAGAIN:
|
||||||
err = nil
|
if err := t.blockOnWrite(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
case unix.EINTR:
|
||||||
|
// retry
|
||||||
|
case unix.EBADF:
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
default:
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return int(n) - 4, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
if t.devFd >= 0 {
|
if t.closed.Swap(true) {
|
||||||
err := syscall.Close(t.devFd)
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closing the write end of the shutdown pipe causes any blocked Poll to
|
||||||
|
// return with POLLHUP on the shutdown fd, so readers/writers wake up and
|
||||||
|
// exit with os.ErrClosed.
|
||||||
|
if t.shutdownW >= 0 {
|
||||||
|
_ = unix.Close(t.shutdownW)
|
||||||
|
t.shutdownW = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.fd >= 0 {
|
||||||
|
if err := unix.Close(t.fd); err != nil {
|
||||||
|
t.l.Error("Error closing device", "error", err)
|
||||||
|
}
|
||||||
|
t.fd = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.shutdownR >= 0 {
|
||||||
|
_ = unix.Close(t.shutdownR)
|
||||||
|
t.shutdownR = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
c := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
||||||
|
defer close(c)
|
||||||
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
|
if err == nil {
|
||||||
|
defer syscall.Close(s)
|
||||||
|
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||||
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).Error("Error closing device")
|
t.l.Error("Error destroying tunnel", "error", err)
|
||||||
}
|
}
|
||||||
t.devFd = -1
|
}()
|
||||||
|
|
||||||
c := make(chan struct{})
|
// wait up to 1 second so we start blocking at the ioctl
|
||||||
go func() {
|
select {
|
||||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
case <-c:
|
||||||
defer close(c)
|
case <-time.After(1 * time.Second):
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err == nil {
|
|
||||||
defer syscall.Close(s)
|
|
||||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Error destroying tunnel")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// wait up to 1 second so we start blocking at the ioctl
|
|
||||||
select {
|
|
||||||
case <-c:
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open existing tun device
|
// Try to open existing tun device
|
||||||
var fd int
|
var fd int
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName != "" {
|
if deviceName != "" {
|
||||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
fd, err = unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||||
// If the device doesn't already exist, request a new one and rename it
|
// If the device doesn't already exist, request a new one and rename it
|
||||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
fd, err = unix.Open("/dev/tun", os.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = unix.SetNonblock(fd, true); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, fmt.Errorf("failed to set tun device as nonblocking: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown pipe lets Close wake any reader/writer blocked in Poll.
|
||||||
|
var pipeFds [2]int
|
||||||
|
if err = unix.Pipe2(pipeFds[:], unix.O_CLOEXEC|unix.O_NONBLOCK); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, fmt.Errorf("failed to create shutdown pipe: %w", err)
|
||||||
|
}
|
||||||
|
shutdownR, shutdownW := pipeFds[0], pipeFds[1]
|
||||||
|
|
||||||
|
closeOnErr := true
|
||||||
|
defer func() {
|
||||||
|
if closeOnErr {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
_ = unix.Close(shutdownR)
|
||||||
|
_ = unix.Close(shutdownW)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Read the name of the interface
|
// Read the name of the interface
|
||||||
var name [16]byte
|
var name [16]byte
|
||||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||||
@@ -237,7 +339,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ctrlErr != nil {
|
if ctrlErr != nil {
|
||||||
return nil, err
|
return nil, ctrlErr
|
||||||
}
|
}
|
||||||
|
|
||||||
ifName := string(bytes.TrimRight(name[:], "\x00"))
|
ifName := string(bytes.TrimRight(name[:], "\x00"))
|
||||||
@@ -253,8 +355,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
defer syscall.Close(s)
|
defer syscall.Close(s)
|
||||||
|
|
||||||
fd := uintptr(s)
|
|
||||||
|
|
||||||
var fromName [16]byte
|
var fromName [16]byte
|
||||||
var toName [16]byte
|
var toName [16]byte
|
||||||
copy(fromName[:], ifName)
|
copy(fromName[:], ifName)
|
||||||
@@ -266,7 +366,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set the device name
|
// Set the device name
|
||||||
_ = ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
_ = ioctl(uintptr(s), syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
@@ -274,13 +374,24 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
devFd: fd,
|
fd: fd,
|
||||||
|
shutdownR: shutdownR,
|
||||||
|
shutdownW: shutdownW,
|
||||||
|
readPoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
|
{Fd: int32(shutdownR), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
writePoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||||
|
{Fd: int32(shutdownR), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
closeOnErr = false
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := t.reload(c, false)
|
err := t.reload(c, false)
|
||||||
@@ -475,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -490,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -14,7 +15,6 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -25,14 +25,14 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||||
t := &tun{
|
t := &tun{
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
|
|||||||
@@ -4,8 +4,10 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -16,7 +18,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -24,9 +25,175 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
|
||||||
|
// A shared eventfd allows Close to wake all readers blocked in poll.
|
||||||
|
type tunFile struct {
|
||||||
|
fd int
|
||||||
|
shutdownFd int
|
||||||
|
lastOne bool
|
||||||
|
readPoll [2]unix.PollFd
|
||||||
|
writePoll [2]unix.PollFd
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
|
||||||
|
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
||||||
|
if err := unix.SetNonblock(fd, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||||
|
}
|
||||||
|
return &tunFile{
|
||||||
|
fd: fd,
|
||||||
|
shutdownFd: r.shutdownFd,
|
||||||
|
readPoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
|
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
writePoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||||
|
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFd(fd int) (*tunFile, error) {
|
||||||
|
if err := unix.SetNonblock(fd, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &tunFile{
|
||||||
|
fd: fd,
|
||||||
|
shutdownFd: shutdownFd,
|
||||||
|
lastOne: true,
|
||||||
|
readPoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
writePoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||||
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) blockOnRead() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(r.readPoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//always reset these!
|
||||||
|
tunEvents := r.readPoll[0].Revents
|
||||||
|
shutdownEvents := r.readPoll[1].Revents
|
||||||
|
r.readPoll[0].Revents = 0
|
||||||
|
r.readPoll[1].Revents = 0
|
||||||
|
//do the err check before trusting the potentially bogus bits we just got
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
} else if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) blockOnWrite() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(r.writePoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//always reset these!
|
||||||
|
tunEvents := r.writePoll[0].Revents
|
||||||
|
shutdownEvents := r.writePoll[1].Revents
|
||||||
|
r.writePoll[0].Revents = 0
|
||||||
|
r.writePoll[1].Revents = 0
|
||||||
|
//do the err check before trusting the potentially bogus bits we just got
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
} else if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) Read(buf []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
if n, err := unix.Read(r.fd, buf); err == nil {
|
||||||
|
return n, nil
|
||||||
|
} else if err == unix.EAGAIN {
|
||||||
|
if err = r.blockOnRead(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
} else if err == unix.EBADF {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
} else {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) Write(buf []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
if n, err := unix.Write(r.fd, buf); err == nil {
|
||||||
|
return n, nil
|
||||||
|
} else if err == unix.EAGAIN {
|
||||||
|
if err = r.blockOnWrite(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
} else if err == unix.EBADF {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
} else {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) wakeForShutdown() error {
|
||||||
|
var buf [8]byte
|
||||||
|
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||||
|
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) Close() error {
|
||||||
|
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.closed = true
|
||||||
|
if r.lastOne {
|
||||||
|
_ = unix.Close(r.shutdownFd)
|
||||||
|
}
|
||||||
|
return unix.Close(r.fd)
|
||||||
|
}
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
*tunFile
|
||||||
fd int
|
readers []*tunFile
|
||||||
|
closeLock sync.Mutex
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MaxMTU int
|
MaxMTU int
|
||||||
@@ -46,7 +213,7 @@ type tun struct {
|
|||||||
routesFromSystem map[netip.Prefix]routing.Gateways
|
routesFromSystem map[netip.Prefix]routing.Gateways
|
||||||
routesFromSystemLock sync.Mutex
|
routesFromSystemLock sync.Mutex
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
@@ -71,10 +238,8 @@ type ifreqQLEN struct {
|
|||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -84,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||||
@@ -115,6 +280,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
nameStr := c.GetString("tun.dev", "")
|
nameStr := c.GetString("tun.dev", "")
|
||||||
copy(req.Name[:], nameStr)
|
copy(req.Name[:], nameStr)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
return nil, &NameError{
|
return nil, &NameError{
|
||||||
Name: nameStr,
|
Name: nameStr,
|
||||||
Underlying: err,
|
Underlying: err,
|
||||||
@@ -122,8 +288,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
}
|
}
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
t, err := newTunGeneric(c, l, fd, vpnNetworks)
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -133,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||||
|
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
|
tfd, err := newTunFd(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
tunFile: tfd,
|
||||||
fd: int(file.Fd()),
|
readers: []*tunFile{tfd},
|
||||||
|
closeLock: sync.Mutex{},
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
@@ -145,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
|
|||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
if err = t.reload(c, true); err != nil {
|
||||||
if err != nil {
|
_ = t.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
if !initial {
|
if !initial {
|
||||||
if oldMaxMTU != newMaxMTU {
|
if oldMaxMTU != newMaxMTU {
|
||||||
t.setMTU()
|
t.setMTU()
|
||||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldDefaultMTU != newDefaultMTU {
|
if oldDefaultMTU != newDefaultMTU {
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
err := t.setDefaultRoute(t.vpnNetworks[i])
|
err := t.setDefaultRoute(t.vpnNetworks[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.Warn(err)
|
t.l.Warn(err.Error())
|
||||||
} else {
|
} else {
|
||||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -239,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
t.closeLock.Lock()
|
||||||
|
defer t.closeLock.Unlock()
|
||||||
|
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -248,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
copy(req.Name[:], t.Device)
|
copy(req.Name[:], t.Device)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
out, err := t.tunFile.newFriend(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return file, nil
|
t.readers = append(t.readers, out)
|
||||||
|
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
@@ -261,29 +443,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
|
||||||
var nn int
|
|
||||||
maximum := len(b)
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err := unix.Write(t.fd, b[nn:maximum])
|
|
||||||
if n > 0 {
|
|
||||||
nn += n
|
|
||||||
}
|
|
||||||
if nn == len(b) {
|
|
||||||
return nn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if n == 0 {
|
|
||||||
return nn, io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
for i, c := range t.Device {
|
for i, c := range t.Device {
|
||||||
o[i] = byte(c)
|
o[i] = byte(c)
|
||||||
@@ -333,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error {
|
|||||||
}
|
}
|
||||||
err = netlink.AddrDel(link, &al[i])
|
err = netlink.AddrDel(link, &al[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
t.l.Error("failed to remove address from tun address list", "error", err)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
t.l.Info("removed address not listed in cert(s)", "removed", al[i].String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,12 +538,12 @@ func (t *tun) Activate() error {
|
|||||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
t.l.Error("Failed to set tun tx queue length", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const modeNone = 1
|
const modeNone = 1
|
||||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
t.l.Warn("Failed to disable link local address generation", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = t.addIPs(link); err != nil {
|
if err = t.addIPs(link); err != nil {
|
||||||
@@ -423,7 +582,7 @@ func (t *tun) setMTU() {
|
|||||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
||||||
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
t.l.Error("Failed to set tun mtu", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
err := netlink.RouteReplace(&nr)
|
err := netlink.RouteReplace(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
|
||||||
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -454,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
|
t.l.Warn("Failed to set default route MTU, retrying",
|
||||||
|
"error", err,
|
||||||
|
"cidr", cidr,
|
||||||
|
"mtu", t.DefaultMTU,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -499,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -531,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) {
|
|||||||
|
|
||||||
err := netlink.RouteDel(&nr)
|
err := netlink.RouteDel(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -562,11 +725,11 @@ func (t *tun) watchRoutes() {
|
|||||||
netlinkOptions := netlink.RouteSubscribeOptions{
|
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||||
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
||||||
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
||||||
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) },
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
|
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
|
||||||
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
|
t.l.Error("failed to subscribe to system route changes", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
link, err := netlink.LinkByName(t.Device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
|
t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device)
|
||||||
return gateways
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -620,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||||
} else {
|
} else {
|
||||||
// Gateway isn't in our overlay network, ignore
|
// Gateway isn't in our overlay network, ignore
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -636,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||||
} else {
|
} else {
|
||||||
// Gateway isn't in our overlay network, ignore
|
// Gateway isn't in our overlay network, ignore
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -671,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||||
if len(gateways) == 0 {
|
if len(gateways) == 0 {
|
||||||
// No gateways relevant to our network, no routing changes required.
|
// No gateways relevant to our network, no routing changes required.
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
t.l.Debug("Ignoring route update, no gateways", "route", r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Dst == nil {
|
if r.Dst == nil {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
t.l.Debug("Ignoring route update, no destination address", "route", r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
t.l.Debug("Ignoring route update, invalid destination address", "route", r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -693,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
|
|
||||||
t.routesFromSystemLock.Lock()
|
t.routesFromSystemLock.Lock()
|
||||||
if r.Type == unix.RTM_NEWROUTE {
|
if r.Type == unix.RTM_NEWROUTE {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
t.l.Info("Adding route", "destination", dst, "via", gateways)
|
||||||
t.routesFromSystem[dst] = gateways
|
t.routesFromSystem[dst] = gateways
|
||||||
newTree.Insert(dst, gateways)
|
newTree.Insert(dst, gateways)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
t.l.Info("Removing route", "destination", dst, "via", gateways)
|
||||||
delete(t.routesFromSystem, dst)
|
delete(t.routesFromSystem, dst)
|
||||||
newTree.Delete(dst)
|
newTree.Delete(dst)
|
||||||
}
|
}
|
||||||
@@ -707,18 +870,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
|
t.closeLock.Lock()
|
||||||
|
defer t.closeLock.Unlock()
|
||||||
|
|
||||||
if t.routeChan != nil {
|
if t.routeChan != nil {
|
||||||
close(t.routeChan)
|
close(t.routeChan)
|
||||||
|
t.routeChan = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.ReadWriteCloser != nil {
|
// Signal all readers blocked in poll to wake up and exit
|
||||||
_ = t.ReadWriteCloser.Close()
|
_ = t.tunFile.wakeForShutdown()
|
||||||
}
|
|
||||||
|
|
||||||
if t.ioctlFd > 0 {
|
if t.ioctlFd > 0 {
|
||||||
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
_ = unix.Close(int(t.ioctlFd))
|
||||||
t.ioctlFd = 0
|
t.ioctlFd = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
for i := range t.readers {
|
||||||
|
if i == 0 {
|
||||||
|
continue //we want to close the zeroth reader last
|
||||||
|
}
|
||||||
|
err := t.readers[i].Close()
|
||||||
|
if err != nil {
|
||||||
|
t.l.Error("error closing tun reader", "reader", i, "error", err)
|
||||||
|
} else {
|
||||||
|
t.l.Info("closed tun reader", "reader", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//this is t.readers[0] too
|
||||||
|
err := t.tunFile.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.l.Error("error closing tun reader", "reader", 0, "error", err)
|
||||||
|
} else {
|
||||||
|
t.l.Info("closed tun reader", "reader", 0)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -15,7 +16,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -63,18 +63,18 @@ type tun struct {
|
|||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
f *os.File
|
f *os.File
|
||||||
fd int
|
fd int
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
@@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
err = unix.SetNonblock(fd, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
l.Warn("Failed to set the tun device as nonblocking", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
@@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -15,7 +16,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -54,7 +54,7 @@ type tun struct {
|
|||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
f *os.File
|
f *os.File
|
||||||
fd int
|
fd int
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
@@ -63,11 +63,11 @@ type tun struct {
|
|||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
@@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
err = unix.SetNonblock(fd, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
l.Warn("Failed to set the tun device as nonblocking", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
@@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -4,14 +4,15 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -21,14 +22,14 @@ type TestTun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *bart.Table[routing.Gateways]
|
routeTree *bart.Table[routing.Gateways]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
rxPackets chan []byte // Packets to receive into nebula
|
rxPackets chan []byte // Packets to receive into nebula
|
||||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||||
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.l.Level >= logrus.DebugLevel {
|
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
|
t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
|
||||||
}
|
}
|
||||||
t.rxPackets <- packet
|
t.rxPackets <- packet
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"crypto"
|
"crypto"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -16,7 +17,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -33,16 +33,16 @@ type winTun struct {
|
|||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
tun *wintun.NativeTun
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||||
err := checkWinTunExists()
|
err := checkWinTunExists()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||||
@@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||||
// Trying a second time resolves the issue.
|
// Trying a second time resolves the issue.
|
||||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
l.Debug("Failed to create wintun device, retrying", "error", err)
|
||||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &NameError{
|
return nil, &NameError{
|
||||||
@@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !foundDefault4 {
|
if !foundDefault4 {
|
||||||
@@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||||||
// See comment on luid.AddRoute
|
// See comment on luid.AddRoute
|
||||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
return NewUserDevice(vpnNetworks)
|
return NewUserDevice(vpnNetworks)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
33
pki.go
33
pki.go
@@ -5,6 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -14,7 +16,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -23,7 +24,7 @@ import (
|
|||||||
type PKI struct {
|
type PKI struct {
|
||||||
cs atomic.Pointer[CertState]
|
cs atomic.Pointer[CertState]
|
||||||
caPool atomic.Pointer[cert.CAPool]
|
caPool atomic.Pointer[cert.CAPool]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type CertState struct {
|
type CertState struct {
|
||||||
@@ -45,7 +46,7 @@ type CertState struct {
|
|||||||
myVpnBroadcastAddrsTable *bart.Lite
|
myVpnBroadcastAddrsTable *bart.Lite
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) {
|
||||||
pki := &PKI{l: l}
|
pki := &PKI{l: l}
|
||||||
err := pki.reload(c, true)
|
err := pki.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -181,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
p.cs.Store(newState)
|
p.cs.Store(newState)
|
||||||
|
|
||||||
if initial {
|
if initial {
|
||||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
p.l.Debug("Client nebula certificate(s)", "cert", newState)
|
||||||
} else {
|
} else {
|
||||||
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
|
p.l.Info("Client certificate(s) refreshed from disk", "cert", newState)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -195,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.caPool.Store(caPool)
|
p.caPool.Store(caPool)
|
||||||
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -486,32 +487,32 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
|
|||||||
return c, b, nil
|
return c, b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) {
|
||||||
var rawCA []byte
|
|
||||||
var err error
|
|
||||||
|
|
||||||
caPathOrPEM := c.GetString("pki.ca", "")
|
caPathOrPEM := c.GetString("pki.ca", "")
|
||||||
if caPathOrPEM == "" {
|
if caPathOrPEM == "" {
|
||||||
return nil, errors.New("no pki.ca path or PEM data provided")
|
return nil, errors.New("no pki.ca path or PEM data provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(caPathOrPEM, "-----BEGIN") {
|
var caReader io.ReadCloser
|
||||||
rawCA = []byte(caPathOrPEM)
|
var err error
|
||||||
|
|
||||||
|
if strings.Contains(caPathOrPEM, "-----BEGIN") {
|
||||||
|
caReader = io.NopCloser(strings.NewReader(caPathOrPEM))
|
||||||
} else {
|
} else {
|
||||||
rawCA, err = os.ReadFile(caPathOrPEM)
|
caReader, err = os.Open(caPathOrPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
|
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
defer caReader.Close()
|
||||||
|
|
||||||
caPool, err := cert.NewCAPoolFromPEM(rawCA)
|
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
|
||||||
if errors.Is(err, cert.ErrExpired) {
|
if errors.Is(err, cert.ErrExpired) {
|
||||||
var expired int
|
var expired int
|
||||||
for _, crt := range caPool.CAs {
|
for _, crt := range caPool.CAs {
|
||||||
if crt.Certificate.Expired(time.Now()) {
|
if crt.Certificate.Expired(time.Now()) {
|
||||||
expired++
|
expired++
|
||||||
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
|
l.Warn("expired certificate present in CA pool", "cert", crt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -529,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
|||||||
caPool.BlocklistFingerprint(fp)
|
caPool.BlocklistFingerprint(fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
l.Info("Blocklisted certificates", "fingerprintCount", len(bl))
|
||||||
}
|
}
|
||||||
|
|
||||||
return caPool, nil
|
return caPool, nil
|
||||||
|
|||||||
121
pki_hup_benchmark_test.go
Normal file
121
pki_hup_benchmark_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
cert_test "github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkReloadConfigWithCAs(b *testing.B) {
|
||||||
|
prevProcs := runtime.GOMAXPROCS(1)
|
||||||
|
b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) })
|
||||||
|
|
||||||
|
for _, size := range []int{100, 250, 500, 1000, 5000} {
|
||||||
|
b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
dir := b.TempDir()
|
||||||
|
|
||||||
|
ca, caKey, caBundle := buildCABundle(b, size)
|
||||||
|
caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle)
|
||||||
|
|
||||||
|
configBody := fmt.Sprintf(`pki:
|
||||||
|
ca: %s
|
||||||
|
cert: %s
|
||||||
|
key: %s
|
||||||
|
`, caPath, certPath, keyPath)
|
||||||
|
|
||||||
|
configPath := filepath.Join(dir, "config.yml")
|
||||||
|
require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600))
|
||||||
|
|
||||||
|
c := config.NewC(l)
|
||||||
|
require.NoError(b, c.Load(dir))
|
||||||
|
|
||||||
|
_, err := NewPKIFromConfig(test.NewLogger(), c)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
c.ReloadConfig()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) {
|
||||||
|
b.Helper()
|
||||||
|
require.GreaterOrEqual(b, count, 1)
|
||||||
|
|
||||||
|
before := time.Now().Add(-24 * time.Hour)
|
||||||
|
after := time.Now().Add(24 * time.Hour)
|
||||||
|
|
||||||
|
ca, _, caKey, pem := cert_test.NewTestCaCert(
|
||||||
|
cert.Version2,
|
||||||
|
cert.Curve_CURVE25519,
|
||||||
|
before,
|
||||||
|
after,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(pem)
|
||||||
|
buf.Write([]byte("\n# a comment!\n"))
|
||||||
|
|
||||||
|
for i := 1; i < count; i++ {
|
||||||
|
_, _, _, extraPEM := cert_test.NewTestCaCert(
|
||||||
|
cert.Version2,
|
||||||
|
cert.Curve_CURVE25519,
|
||||||
|
time.Now(),
|
||||||
|
time.Now().Add(time.Hour),
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
buf.Write([]byte("\n# a comment!\n"))
|
||||||
|
buf.Write(extraPEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ca, caKey, buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}
|
||||||
|
|
||||||
|
_, _, keyPEM, certPEM := cert_test.NewTestCert(
|
||||||
|
cert.Version2,
|
||||||
|
cert.Curve_CURVE25519,
|
||||||
|
ca,
|
||||||
|
caKey,
|
||||||
|
"reload-benchmark",
|
||||||
|
time.Now(),
|
||||||
|
time.Now().Add(time.Hour),
|
||||||
|
networks,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
caPath := filepath.Join(dir, "ca.pem")
|
||||||
|
certPath := filepath.Join(dir, "cert.pem")
|
||||||
|
keyPath := filepath.Join(dir, "key.pem")
|
||||||
|
|
||||||
|
require.NoError(b, os.WriteFile(caPath, caBundle, 0o600))
|
||||||
|
require.NoError(b, os.WriteFile(certPath, certPEM, 0o600))
|
||||||
|
require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600))
|
||||||
|
|
||||||
|
return caPath, certPath, keyPath
|
||||||
|
}
|
||||||
14
punchy.go
14
punchy.go
@@ -1,10 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,10 +14,10 @@ type Punchy struct {
|
|||||||
delay atomic.Int64
|
delay atomic.Int64
|
||||||
respondDelay atomic.Int64
|
respondDelay atomic.Int64
|
||||||
punchEverything atomic.Bool
|
punchEverything atomic.Bool
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
|
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
|
||||||
p := &Punchy{l: l}
|
p := &Punchy{l: l}
|
||||||
|
|
||||||
p.reload(c, true)
|
p.reload(c, true)
|
||||||
@@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
|||||||
p.respond.Store(yes)
|
p.respond.Store(yes)
|
||||||
|
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.Infof("punchy.respond changed to %v", p.GetRespond())
|
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
|||||||
if initial || c.HasChanged("punchy.delay") {
|
if initial || c.HasChanged("punchy.delay") {
|
||||||
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
|
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.Infof("punchy.delay changed to %s", p.GetDelay())
|
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("punchy.target_all_remotes") {
|
if initial || c.HasChanged("punchy.target_all_remotes") {
|
||||||
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
|
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
|
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("punchy.respond_delay") {
|
if initial || c.HasChanged("punchy.respond_delay") {
|
||||||
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
|
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay())
|
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
173
punchy_test.go
173
punchy_test.go
@@ -1,6 +1,8 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
|||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.False(t, p.GetPunch())
|
assert.False(t, p.GetPunch())
|
||||||
assert.False(t, p.GetRespond())
|
assert.False(t, p.GetRespond())
|
||||||
assert.Equal(t, time.Second, p.GetDelay())
|
assert.Equal(t, time.Second, p.GetDelay())
|
||||||
@@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// punchy deprecation
|
// punchy deprecation
|
||||||
c.Settings["punchy"] = true
|
c.Settings["punchy"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetPunch())
|
assert.True(t, p.GetPunch())
|
||||||
|
|
||||||
// punchy.punch
|
// punchy.punch
|
||||||
c.Settings["punchy"] = map[string]any{"punch": true}
|
c.Settings["punchy"] = map[string]any{"punch": true}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetPunch())
|
assert.True(t, p.GetPunch())
|
||||||
|
|
||||||
// punch_back deprecation
|
// punch_back deprecation
|
||||||
c.Settings["punch_back"] = true
|
c.Settings["punch_back"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
|
|
||||||
// punchy.respond
|
// punchy.respond
|
||||||
c.Settings["punchy"] = map[string]any{"respond": true}
|
c.Settings["punchy"] = map[string]any{"respond": true}
|
||||||
c.Settings["punch_back"] = false
|
c.Settings["punch_back"] = false
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
|
|
||||||
// punchy.delay
|
// punchy.delay
|
||||||
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.Equal(t, time.Minute, p.GetDelay())
|
assert.Equal(t, time.Minute, p.GetDelay())
|
||||||
|
|
||||||
// punchy.respond_delay
|
// punchy.respond_delay
|
||||||
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +64,7 @@ punchy:
|
|||||||
delay: 1m
|
delay: 1m
|
||||||
respond: false
|
respond: false
|
||||||
`))
|
`))
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.Equal(t, delay, p.GetDelay())
|
assert.Equal(t, delay, p.GetDelay())
|
||||||
assert.False(t, p.GetRespond())
|
assert.False(t, p.GetRespond())
|
||||||
|
|
||||||
@@ -76,3 +78,158 @@ punchy:
|
|||||||
assert.Equal(t, newDelay, p.GetDelay())
|
assert.Equal(t, newDelay, p.GetDelay())
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The tests below pin the shape of each log line Punchy produces so changes
|
||||||
|
// cannot silently break whatever operators are grepping for. The assertions
|
||||||
|
// are on the structured message + attrs (e.g. "punchy.respond changed" with
|
||||||
|
// a respond=true field) rather than a formatted string.
|
||||||
|
//
|
||||||
|
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
|
||||||
|
// not supported" warning whenever any key under punchy changes, because of
|
||||||
|
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
|
||||||
|
// punchy form. The tests filter by message rather than asserting total
|
||||||
|
// entry counts so that warning is tolerated without being locked into
|
||||||
|
// the format.
|
||||||
|
|
||||||
|
type capturedEntry struct {
|
||||||
|
Level slog.Level
|
||||||
|
Msg string
|
||||||
|
Attrs map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// capturingHandler is a slog.Handler that records each Record it receives so
|
||||||
|
// tests can assert on the level, message, and attribute map of individual log
|
||||||
|
// lines without coupling to any specific text format.
|
||||||
|
type capturingHandler struct {
|
||||||
|
entries []capturedEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
|
||||||
|
|
||||||
|
func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error {
|
||||||
|
e := capturedEntry{
|
||||||
|
Level: r.Level,
|
||||||
|
Msg: r.Message,
|
||||||
|
Attrs: make(map[string]any),
|
||||||
|
}
|
||||||
|
r.Attrs(func(a slog.Attr) bool {
|
||||||
|
e.Attrs[a.Key] = a.Value.Resolve().Any()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
h.entries = append(h.entries, e)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
|
||||||
|
func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h }
|
||||||
|
|
||||||
|
func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) {
|
||||||
|
t.Helper()
|
||||||
|
hook := &capturingHandler{}
|
||||||
|
return slog.New(hook), hook
|
||||||
|
}
|
||||||
|
|
||||||
|
func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry {
|
||||||
|
t.Helper()
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.Msg == msg {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("no entry with message %q among %d entries", msg, len(entries))
|
||||||
|
return capturedEntry{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
|
||||||
|
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy enabled")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Empty(t, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||||
|
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy disabled")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Empty(t, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
|
||||||
|
assert.Equal(t, slog.LevelWarn, entry.Level)
|
||||||
|
assert.Empty(t, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.respond changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"respond": true}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.delay changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"delay": 10 * time.Second}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.target_all_remotes changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"target_all_remotes": true}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.respond_delay changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|||||||
167
relay_manager.go
167
relay_manager.go
@@ -5,22 +5,22 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
type relayManager struct {
|
type relayManager struct {
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
hostmap *HostMap
|
hostmap *HostMap
|
||||||
amRelay atomic.Bool
|
amRelay atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager {
|
func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
|
||||||
rm := &relayManager{
|
rm := &relayManager{
|
||||||
l: l,
|
l: l,
|
||||||
hostmap: hostmap,
|
hostmap: hostmap,
|
||||||
@@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c
|
|||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := rm.reload(c, false)
|
err := rm.reload(c, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to reload relay_manager")
|
rm.l.Error("Failed to reload relay_manager", "error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return rm
|
return rm
|
||||||
@@ -52,10 +52,10 @@ func (rm *relayManager) setAmRelay(v bool) {
|
|||||||
|
|
||||||
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
|
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
|
||||||
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
|
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
|
||||||
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
defer hm.Unlock()
|
defer hm.Unlock()
|
||||||
for i := 0; i < 32; i++ {
|
for range 32 {
|
||||||
index, err := generateIndex(l)
|
index, err := generateIndex(l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
|
|||||||
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
|
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
|
||||||
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
|
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
|
||||||
if !ok {
|
if !ok {
|
||||||
fields := logrus.Fields{
|
var relayFrom, relayTo any
|
||||||
"relay": relayHostInfo.vpnAddrs[0],
|
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.RelayFromAddr == nil {
|
if m.RelayFromAddr == nil {
|
||||||
fields["relayFrom"] = m.OldRelayFromAddr
|
relayFrom = m.OldRelayFromAddr
|
||||||
} else {
|
} else {
|
||||||
fields["relayFrom"] = m.RelayFromAddr
|
relayFrom = m.RelayFromAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.RelayToAddr == nil {
|
if m.RelayToAddr == nil {
|
||||||
fields["relayTo"] = m.OldRelayToAddr
|
relayTo = m.OldRelayToAddr
|
||||||
} else {
|
} else {
|
||||||
fields["relayTo"] = m.RelayToAddr
|
relayTo = m.RelayToAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
rm.l.WithFields(fields).Info("relayManager failed to update relay")
|
rm.l.Info("relayManager failed to update relay",
|
||||||
|
"relay", relayHostInfo.vpnAddrs[0],
|
||||||
|
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||||
|
"relayFrom", relayFrom,
|
||||||
|
"relayTo", relayTo,
|
||||||
|
)
|
||||||
return nil, fmt.Errorf("unknown relay")
|
return nil, fmt.Errorf("unknown relay")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
|
|||||||
msg := &NebulaControl{}
|
msg := &NebulaControl{}
|
||||||
err := msg.Unmarshal(d)
|
err := msg.Unmarshal(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
|
h.logger(f.l).Error("Failed to unmarshal control message", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
|
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("handleCreateRelayResponse",
|
||||||
"relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
|
"relayFrom", protoAddrToNetAddr(m.RelayFromAddr),
|
||||||
"relayTo": protoAddrToNetAddr(m.RelayToAddr),
|
"relayTo", protoAddrToNetAddr(m.RelayToAddr),
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": m.ResponderRelayIndex,
|
"responderRelayIndex", m.ResponderRelayIndex,
|
||||||
"vpnAddrs": h.vpnAddrs}).
|
"vpnAddrs", h.vpnAddrs,
|
||||||
Info("handleCreateRelayResponse")
|
)
|
||||||
|
|
||||||
target := m.RelayToAddr
|
target := m.RelayToAddr
|
||||||
targetAddr := protoAddrToNetAddr(target)
|
targetAddr := protoAddrToNetAddr(target)
|
||||||
|
|
||||||
relay, err := rm.EstablishRelay(h, m)
|
relay, err := rm.EstablishRelay(h, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rm.l.WithError(err).Error("Failed to update relay for relayTo")
|
rm.l.Error("Failed to update relay for relayTo", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Do I need to complete the relays now?
|
// Do I need to complete the relays now?
|
||||||
@@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
|
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
|
||||||
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
|
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
|
||||||
if peerHostInfo == nil {
|
if peerHostInfo == nil {
|
||||||
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
|
rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
|
rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch peerRelay.State {
|
switch peerRelay.State {
|
||||||
@@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
peer := peerHostInfo.vpnAddrs[0]
|
peer := peerHostInfo.vpnAddrs[0]
|
||||||
if !peer.Is4() {
|
if !peer.Is4() {
|
||||||
rm.l.WithField("relayFrom", peer).
|
rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address",
|
||||||
WithField("relayTo", target).
|
"relayFrom", peer,
|
||||||
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
|
"relayTo", target,
|
||||||
WithField("responderRelayIndex", resp.ResponderRelayIndex).
|
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||||
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
|
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||||
Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address")
|
"vpnAddrs", peerHostInfo.vpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rm.l.WithError(err).
|
rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
|
||||||
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("send CreateRelayResponse",
|
||||||
"relayFrom": resp.RelayFromAddr,
|
"relayFrom", resp.RelayFromAddr,
|
||||||
"relayTo": resp.RelayToAddr,
|
"relayTo", resp.RelayToAddr,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||||
"vpnAddrs": peerHostInfo.vpnAddrs}).
|
"vpnAddrs", peerHostInfo.vpnAddrs,
|
||||||
Info("send CreateRelayResponse")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
from := protoAddrToNetAddr(m.RelayFromAddr)
|
from := protoAddrToNetAddr(m.RelayFromAddr)
|
||||||
target := protoAddrToNetAddr(m.RelayToAddr)
|
target := protoAddrToNetAddr(m.RelayToAddr)
|
||||||
|
|
||||||
logMsg := rm.l.WithFields(logrus.Fields{
|
logMsg := rm.l.With(
|
||||||
"relayFrom": from,
|
"relayFrom", from,
|
||||||
"relayTo": target,
|
"relayTo", target,
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||||
"vpnAddrs": h.vpnAddrs})
|
"vpnAddrs", h.vpnAddrs,
|
||||||
|
)
|
||||||
|
|
||||||
logMsg.Info("handleCreateRelayRequest")
|
logMsg.Info("handleCreateRelayRequest")
|
||||||
// Is the source of the relay me? This should never happen, but did happen due to
|
// Is the source of the relay me? This should never happen, but did happen due to
|
||||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||||
if f.myVpnAddrsTable.Contains(from) {
|
if f.myVpnAddrsTable.Contains(from) {
|
||||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
logMsg.Error("Discarding relay request from myself", "myIP", from)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
||||||
// We got a brand new Relay request, because its index is different than what we saw before.
|
// We got a brand new Relay request, because its index is different than what we saw before.
|
||||||
// This should never happen. The peer should never change an index, once created.
|
// This should never happen. The peer should never change an index, once created.
|
||||||
logMsg.WithFields(logrus.Fields{
|
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
|
||||||
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
"existingRemoteIndex", existingRelay.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case Disestablished:
|
case Disestablished:
|
||||||
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
||||||
// We got a brand new Relay request, because its index is different than what we saw before.
|
// We got a brand new Relay request, because its index is different than what we saw before.
|
||||||
// This should never happen. The peer should never change an index, once created.
|
// This should never happen. The peer should never change an index, once created.
|
||||||
logMsg.WithFields(logrus.Fields{
|
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
|
||||||
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
"existingRemoteIndex", existingRelay.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Mark the relay as 'Established' because it's safe to use again
|
// Mark the relay as 'Established' because it's safe to use again
|
||||||
h.relayState.UpdateRelayForByIpState(from, Established)
|
h.relayState.UpdateRelayForByIpState(from, Established)
|
||||||
case PeerRequested:
|
case PeerRequested:
|
||||||
// I should never be in this state, because I am terminal, not forwarding.
|
// I should never be in this state, because I am terminal, not forwarding.
|
||||||
logMsg.WithFields(logrus.Fields{
|
logMsg.Error("Unexpected Relay State found",
|
||||||
"existingRemoteIndex": existingRelay.RemoteIndex,
|
"existingRemoteIndex", existingRelay.RemoteIndex,
|
||||||
"state": existingRelay.State}).Error("Unexpected Relay State found")
|
"state", existingRelay.State)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
|
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.WithError(err).Error("Failed to add relay")
|
logMsg.Error("Failed to add relay", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
relay, ok := h.relayState.QueryRelayForByIp(from)
|
relay, ok := h.relayState.QueryRelayForByIp(from)
|
||||||
if !ok {
|
if !ok {
|
||||||
logMsg.WithField("from", from).Error("Relay State not found")
|
logMsg.Error("Relay State not found", "from", from)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
|
||||||
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("send CreateRelayResponse",
|
||||||
"relayFrom": from,
|
"relayFrom", from,
|
||||||
"relayTo": target,
|
"relayTo", target,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||||
"vpnAddrs": h.vpnAddrs}).
|
"vpnAddrs", h.vpnAddrs,
|
||||||
Info("send CreateRelayResponse")
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
@@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if !h.vpnAddrs[0].Is4() {
|
if !h.vpnAddrs[0].Is4() {
|
||||||
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
|
rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address",
|
||||||
WithField("relayTo", target).
|
"relayFrom", h.vpnAddrs[0],
|
||||||
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
|
"relayTo", target,
|
||||||
WithField("responderRelayIndex", req.ResponderRelayIndex).
|
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||||
WithField("vpnAddr", target).
|
"responderRelayIndex", req.ResponderRelayIndex,
|
||||||
Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address")
|
"vpnAddr", target,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err)
|
||||||
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": h.vpnAddrs[0],
|
"relayFrom", h.vpnAddrs[0],
|
||||||
"relayTo": target,
|
"relayTo", target,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex", req.ResponderRelayIndex,
|
||||||
"vpnAddr": target}).
|
"vpnAddr", target,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also track the half-created Relay state just received
|
// Also track the half-created Relay state just received
|
||||||
@@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
if !ok {
|
if !ok {
|
||||||
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err)
|
||||||
WithError(err).Error("relayManager Failed to allocate a local index for relay")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -10,8 +11,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// forEachFunc is used to benefit folks that want to do work inside the lock
|
// forEachFunc is used to benefit folks that want to do work inside the lock
|
||||||
@@ -66,11 +65,11 @@ type hostnamesResults struct {
|
|||||||
network string
|
network string
|
||||||
lookupTimeout time.Duration
|
lookupTimeout time.Duration
|
||||||
cancelFn func()
|
cancelFn func()
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
ips atomic.Pointer[map[netip.AddrPort]struct{}]
|
ips atomic.Pointer[map[netip.AddrPort]struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
|
func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
|
||||||
r := &hostnamesResults{
|
r := &hostnamesResults{
|
||||||
hostnames: make([]hostnamePort, len(hostPorts)),
|
hostnames: make([]hostnamePort, len(hostPorts)),
|
||||||
network: network,
|
network: network,
|
||||||
@@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
|
|||||||
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
|
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
|
l.Error("DNS resolution failed for static_map host",
|
||||||
|
"hostname", hostPort.name,
|
||||||
|
"network", r.network,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, a := range addrs {
|
for _, a := range addrs {
|
||||||
@@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if different {
|
if different {
|
||||||
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
|
l.Info("DNS results changed for host list",
|
||||||
|
"origSet", origSet,
|
||||||
|
"newSet", netipAddrs,
|
||||||
|
)
|
||||||
r.ips.Store(&netipAddrs)
|
r.ips.Store(&netipAddrs)
|
||||||
onUpdate()
|
onUpdate()
|
||||||
}
|
}
|
||||||
@@ -404,12 +410,7 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
|
|||||||
|
|
||||||
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
|
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
|
||||||
func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
|
func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
|
||||||
for _, v := range r.badRemotes {
|
return slices.Contains(r.badRemotes, remote)
|
||||||
if v == remote {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
|
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
|
||||||
|
|||||||
@@ -44,7 +44,10 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(control *nebula.Control) (*Service, error) {
|
func New(control *nebula.Control) (*Service, error) {
|
||||||
control.Start()
|
wait, err := control.Start()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ctx := control.Context()
|
ctx := control.Context()
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Add the nebula wait function to the group so a fatal reader error
|
||||||
|
// propagates out through errgroup.Wait().
|
||||||
|
eg.Go(func() error {
|
||||||
|
return wait()
|
||||||
|
})
|
||||||
|
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
@@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := logrus.New()
|
logger := logging.NewLogger(os.Stdout)
|
||||||
logger.Out = os.Stdout
|
|
||||||
|
|
||||||
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
159
ssh.go
159
ssh.go
@@ -6,19 +6,21 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,12 +57,12 @@ type sshDeviceInfoFlags struct {
|
|||||||
Pretty bool
|
Pretty bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
|
func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) {
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshRun, err := configSSH(l, ssh, c)
|
sshRun, err := configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to reconfigure the sshd")
|
l.Error("Failed to reconfigure the sshd", "error", err)
|
||||||
ssh.Stop()
|
ssh.Stop()
|
||||||
}
|
}
|
||||||
if sshRun != nil {
|
if sshRun != nil {
|
||||||
@@ -76,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
|
|||||||
// updates the passed-in SSHServer. On success, it returns a function
|
// updates the passed-in SSHServer. On success, it returns a function
|
||||||
// that callers may invoke to run the configured ssh server. On
|
// that callers may invoke to run the configured ssh server. On
|
||||||
// failure, it returns nil, error.
|
// failure, it returns nil, error.
|
||||||
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
|
func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
|
||||||
listen := c.GetString("sshd.listen", "")
|
listen := c.GetString("sshd.listen", "")
|
||||||
if listen == "" {
|
if listen == "" {
|
||||||
return nil, fmt.Errorf("sshd.listen must be provided")
|
return nil, fmt.Errorf("sshd.listen must be provided")
|
||||||
@@ -118,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
for _, caAuthorizedKey := range rawCAs {
|
for _, caAuthorizedKey := range rawCAs {
|
||||||
err := ssh.AddTrustedCA(caAuthorizedKey)
|
err := ssh.AddTrustedCA(caAuthorizedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
|
l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -129,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
for _, rk := range keys {
|
for _, rk := range keys {
|
||||||
kDef, ok := rk.(map[string]any)
|
kDef, ok := rk.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
|
l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := kDef["user"].(string)
|
user, ok := kDef["user"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
|
l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
case string:
|
case string:
|
||||||
err := ssh.AddAuthorizedKey(user, v)
|
err := ssh.AddAuthorizedKey(user, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
|
l.Warn("Failed to authorize key",
|
||||||
|
"error", err,
|
||||||
|
"sshKeyConfig", rk,
|
||||||
|
"sshKey", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
for _, subK := range v {
|
for _, subK := range v {
|
||||||
sk, ok := subK.(string)
|
sk, ok := subK.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
|
l.Warn("Did not understand ssh key",
|
||||||
|
"sshKeyConfig", rk,
|
||||||
|
"sshKey", subK,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ssh.AddAuthorizedKey(user, sk)
|
err := ssh.AddAuthorizedKey(user, sk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key")
|
l.Warn("Failed to authorize key",
|
||||||
|
"error", err,
|
||||||
|
"sshKeyConfig", sk,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
|
l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -176,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
ssh.Stop()
|
ssh.Stop()
|
||||||
runner = func() {
|
runner = func() {
|
||||||
if err := ssh.Run(listen); err != nil {
|
if err := ssh.Run(listen); err != nil {
|
||||||
l.WithField("err", err).Warn("Failed to run the SSH server")
|
l.Warn("Failed to run the SSH server", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -186,7 +198,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
return runner, nil
|
return runner, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
|
func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
|
||||||
|
// sandboxDir defaults to a dir in temp. The intention is that end user will
|
||||||
|
// create this dir as needed. Overriding this config value to "" allows
|
||||||
|
// writing to anywhere in the system.
|
||||||
|
defaultDir := filepath.Join(os.TempDir(), "nebula-debug")
|
||||||
|
sandboxDir := c.GetString("sshd.sandbox_dir", defaultDir)
|
||||||
|
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "list-hostmap",
|
Name: "list-hostmap",
|
||||||
ShortDescription: "List all known previously connected hosts",
|
ShortDescription: "List all known previously connected hosts",
|
||||||
@@ -245,7 +263,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "start-cpu-profile",
|
Name: "start-cpu-profile",
|
||||||
ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`",
|
ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`",
|
||||||
Callback: sshStartCpuProfile,
|
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||||
|
return sshStartCpuProfile(sandboxDir, fs, a, w)
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
@@ -260,7 +280,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "save-heap-profile",
|
Name: "save-heap-profile",
|
||||||
ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`",
|
ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`",
|
||||||
Callback: sshGetHeapProfile,
|
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||||
|
return sshGetHeapProfile(sandboxDir, fs, a, w)
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
@@ -272,7 +294,9 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "save-mutex-profile",
|
Name: "save-mutex-profile",
|
||||||
ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`",
|
ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`",
|
||||||
Callback: sshGetMutexProfile,
|
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||||
|
return sshGetMutexProfile(sandboxDir, fs, a, w)
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
@@ -505,13 +529,43 @@ func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
|
// sshSanitizeFilePath validates that the given file path is within the sandbox directory.
|
||||||
|
// If sandboxDir is empty, the path is returned as-is for backwards compatibility.
|
||||||
|
func sshSanitizeFilePath(sandboxDir, filePath string) (string, error) {
|
||||||
|
if sandboxDir == "" {
|
||||||
|
return filePath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean and resolve the path relative to the sandbox directory
|
||||||
|
if !filepath.IsAbs(filePath) {
|
||||||
|
filePath = filepath.Join(sandboxDir, filePath)
|
||||||
|
}
|
||||||
|
cleaned := filepath.Clean(filePath)
|
||||||
|
|
||||||
|
// Ensure the resolved path is within the sandbox directory
|
||||||
|
cleanedSandbox := filepath.Clean(sandboxDir)
|
||||||
|
if cleaned == cleanedSandbox {
|
||||||
|
return "", fmt.Errorf("path %q resolves to the sandbox directory itself %q", filePath, sandboxDir)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(cleaned, cleanedSandbox+string(filepath.Separator)) {
|
||||||
|
return "", fmt.Errorf("path %q is outside the sandbox directory %q", filePath, sandboxDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleaned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshStartCpuProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
err := w.WriteLine("No path to write profile provided")
|
err := w.WriteLine("No path to write profile provided")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := os.Create(a[0])
|
filePath, err := sshSanitizeFilePath(sandboxDir, a[0])
|
||||||
|
if err != nil {
|
||||||
|
return w.WriteLine(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Create(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
|
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
|
||||||
return err
|
return err
|
||||||
@@ -675,12 +729,17 @@ func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
|
|||||||
return w.WriteLine("Changed")
|
return w.WriteLine("Changed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
|
func sshGetHeapProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine("No path to write profile provided")
|
return w.WriteLine("No path to write profile provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := os.Create(a[0])
|
filePath, err := sshSanitizeFilePath(sandboxDir, a[0])
|
||||||
|
if err != nil {
|
||||||
|
return w.WriteLine(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Create(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
|
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
|
||||||
return err
|
return err
|
||||||
@@ -711,12 +770,17 @@ func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
|
|||||||
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
|
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
|
func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine("No path to write profile provided")
|
return w.WriteLine("No path to write profile provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := os.Create(a[0])
|
filePath, err := sshSanitizeFilePath(sandboxDir, a[0])
|
||||||
|
if err != nil {
|
||||||
|
return w.WriteLine(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Create(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
|
return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
|
||||||
}
|
}
|
||||||
@@ -735,36 +799,45 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
|
|||||||
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||||
|
ctrl, ok := l.Handler().(interface {
|
||||||
|
GetLevel() slog.Level
|
||||||
|
SetLevel(slog.Level)
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return w.WriteLine("Log level is not reconfigurable on this logger")
|
||||||
|
}
|
||||||
|
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
|
||||||
}
|
}
|
||||||
|
|
||||||
level, err := logrus.ParseLevel(a[0])
|
level, err := logging.ParseLevel(strings.ToLower(a[0]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels))
|
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a))
|
||||||
}
|
}
|
||||||
|
|
||||||
l.SetLevel(level)
|
ctrl.SetLevel(level)
|
||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||||
|
ctrl, ok := l.Handler().(interface {
|
||||||
|
GetFormat() string
|
||||||
|
SetFormat(string) error
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return w.WriteLine("Log format is not reconfigurable on this logger")
|
||||||
|
}
|
||||||
|
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
|
||||||
}
|
}
|
||||||
|
|
||||||
logFormat := strings.ToLower(a[0])
|
if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil {
|
||||||
switch logFormat {
|
return err
|
||||||
case "text":
|
|
||||||
l.Formatter = &logrus.TextFormatter{}
|
|
||||||
case "json":
|
|
||||||
l.Formatter = &logrus.JSONFormatter{}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
|
||||||
}
|
}
|
||||||
|
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
|
||||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||||
@@ -831,9 +904,7 @@ func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
|
|||||||
|
|
||||||
relays := map[uint32]*HostInfo{}
|
relays := map[uint32]*HostInfo{}
|
||||||
ifce.hostMap.Lock()
|
ifce.hostMap.Lock()
|
||||||
for k, v := range ifce.hostMap.Relays {
|
maps.Copy(relays, ifce.hostMap.Relays)
|
||||||
relays[k] = v
|
|
||||||
}
|
|
||||||
ifce.hostMap.Unlock()
|
ifce.hostMap.Unlock()
|
||||||
|
|
||||||
type RelayFor struct {
|
type RelayFor struct {
|
||||||
|
|||||||
108
sshd/server.go
108
sshd/server.go
@@ -2,19 +2,19 @@ package sshd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/armon/go-radix"
|
"github.com/armon/go-radix"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SSHServer struct {
|
type SSHServer struct {
|
||||||
config *ssh.ServerConfig
|
config *ssh.ServerConfig
|
||||||
l *logrus.Entry
|
l *slog.Logger
|
||||||
|
|
||||||
certChecker *ssh.CertChecker
|
certChecker *ssh.CertChecker
|
||||||
|
|
||||||
@@ -27,20 +27,21 @@ type SSHServer struct {
|
|||||||
commands *radix.Tree
|
commands *radix.Tree
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
|
|
||||||
// Locks the conns/counter to avoid concurrent map access
|
// Call the cancel() function to stop all active sessions
|
||||||
connsLock sync.Mutex
|
ctx context.Context
|
||||||
conns map[int]*session
|
cancel func()
|
||||||
counter int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
||||||
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
|
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
s := &SSHServer{
|
s := &SSHServer{
|
||||||
trustedKeys: make(map[string]map[string]bool),
|
trustedKeys: make(map[string]map[string]bool),
|
||||||
l: l,
|
l: l,
|
||||||
commands: radix.New(),
|
commands: radix.New(),
|
||||||
conns: make(map[int]*session),
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
cc := ssh.CertChecker{
|
cc := ssh.CertChecker{
|
||||||
@@ -120,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.trustedCAs = append(s.trustedCAs, pk)
|
s.trustedCAs = append(s.trustedCAs, pk)
|
||||||
s.l.WithField("sshKey", pubKey).Info("Trusted CA key")
|
s.l.Info("Trusted CA key", "sshKey", pubKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tk[string(pk.Marshal())] = true
|
tk[string(pk.Marshal())] = true
|
||||||
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
|
s.l.Info("Authorized ssh key",
|
||||||
|
"sshKey", pubKey,
|
||||||
|
"sshUser", user,
|
||||||
|
)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,7 +159,7 @@ func (s *SSHServer) Run(addr string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.l.WithField("sshListener", addr).Info("SSH server is listening")
|
s.l.Info("SSH server is listening", "sshListener", addr)
|
||||||
|
|
||||||
// Run loops until there is an error
|
// Run loops until there is an error
|
||||||
s.run()
|
s.run()
|
||||||
@@ -171,48 +175,54 @@ func (s *SSHServer) run() {
|
|||||||
c, err := s.listener.Accept()
|
c, err := s.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, net.ErrClosed) {
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
s.l.WithError(err).Warn("Error in listener, shutting down")
|
s.l.Warn("Error in listener, shutting down", "error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
go func(c net.Conn) {
|
||||||
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
|
// NewServerConn may block while waiting for the client to complete the handshake.
|
||||||
fp := ""
|
// Ensure that a bad client doesn't hurt us by checking for the parent context
|
||||||
if conn != nil {
|
// cancellation before calling NewServerConn, and forcing the socket to close when
|
||||||
fp = conn.Permissions.Extensions["fp"]
|
// the context is cancelled.
|
||||||
}
|
sessionContext, sessionCancel := context.WithCancel(s.ctx)
|
||||||
|
go func() {
|
||||||
if err != nil {
|
<-sessionContext.Done()
|
||||||
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
|
c.Close()
|
||||||
|
}()
|
||||||
|
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
|
||||||
|
fp := ""
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
l = l.WithField("sshUser", conn.User())
|
fp = conn.Permissions.Extensions["fp"]
|
||||||
conn.Close()
|
|
||||||
}
|
}
|
||||||
if fp != "" {
|
|
||||||
l = l.WithField("sshFingerprint", fp)
|
if err != nil {
|
||||||
|
l := s.l.With(
|
||||||
|
"error", err,
|
||||||
|
"remoteAddress", c.RemoteAddr(),
|
||||||
|
)
|
||||||
|
if conn != nil {
|
||||||
|
l = l.With("sshUser", conn.User())
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
if fp != "" {
|
||||||
|
l = l.With("sshFingerprint", fp)
|
||||||
|
}
|
||||||
|
l.Warn("failed to handshake")
|
||||||
|
sessionCancel()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
l.Warn("failed to handshake")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
l := s.l.WithField("sshUser", conn.User())
|
l := s.l.With("sshUser", conn.User())
|
||||||
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
|
l.Info("ssh user logged in",
|
||||||
|
"remoteAddress", c.RemoteAddr(),
|
||||||
|
"sshFingerprint", fp,
|
||||||
|
)
|
||||||
|
|
||||||
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
|
NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session"))
|
||||||
s.connsLock.Lock()
|
|
||||||
s.counter++
|
|
||||||
counter := s.counter
|
|
||||||
s.conns[counter] = session
|
|
||||||
s.connsLock.Unlock()
|
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
go func() {
|
|
||||||
<-session.exitChan
|
}(c)
|
||||||
s.l.WithField("id", counter).Debug("closing conn")
|
|
||||||
s.connsLock.Lock()
|
|
||||||
delete(s.conns, counter)
|
|
||||||
s.connsLock.Unlock()
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,15 +230,11 @@ func (s *SSHServer) Stop() {
|
|||||||
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
|
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
if err := s.listener.Close(); err != nil {
|
if err := s.listener.Close(); err != nil {
|
||||||
s.l.WithError(err).Warn("Failed to close the sshd listener")
|
s.l.Warn("Failed to close the sshd listener", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHServer) closeSessions() {
|
func (s *SSHServer) closeSessions() {
|
||||||
s.connsLock.Lock()
|
s.cancel()
|
||||||
for _, c := range s.conns {
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
s.connsLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user